From 1fb5dfae9d7cc769e3701422555333ee180c4d93 Mon Sep 17 00:00:00 2001 From: james-mcnulty Date: Mon, 16 Mar 2026 16:31:57 -0700 Subject: [PATCH 1/8] Add Push Mode (Task Dispatchers and Pushers) --- Cargo.lock | 56 ++++++------ Cargo.toml | 10 +-- README.md | 2 +- src/config.rs | 20 +++++ src/dispatch.rs | 190 +++++++++++++++++++++++++++++++++++++++ src/grpc/server.rs | 8 ++ src/grpc/server_tests.rs | 42 ++++++--- src/lib.rs | 1 + src/main.rs | 52 ++++++++--- 9 files changed, 325 insertions(+), 56 deletions(-) create mode 100644 src/dispatch.rs diff --git a/Cargo.lock b/Cargo.lock index ec1f3551..d2a2ac9a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1189,7 +1189,7 @@ dependencies = [ "libc", "percent-encoding", "pin-project-lite", - "socket2 0.6.0", + "socket2", "tokio", "tower-service", "tracing", @@ -1999,9 +1999,9 @@ dependencies = [ [[package]] name = "prost" -version = "0.13.5" +version = "0.14.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2796faa41db3ec313a31f7624d9286acf277b52de526150b7e69f3debf891ee5" +checksum = "d2ea70524a2f82d518bce41317d0fae74151505651af45faf1ffbd6fd33f0568" dependencies = [ "bytes", "prost-derive", @@ -2009,9 +2009,9 @@ dependencies = [ [[package]] name = "prost-derive" -version = "0.13.5" +version = "0.14.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a56d757972c98b346a9b766e3f02746cde6dd1cd1d1d563472929fdd74bec4d" +checksum = "27c6023962132f4b30eb4c172c91ce92d933da334c59c23cddee82358ddafb0b" dependencies = [ "anyhow", "itertools", @@ -2022,9 +2022,9 @@ dependencies = [ [[package]] name = "prost-types" -version = "0.13.5" +version = "0.14.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "52c2c1bf36ddb1a1c396b3601a3cec27c2462e45f07c386894ec3ccf5332bd16" +checksum = "8991c4cbdb8bc5b11f0b074ffe286c30e523de90fee5ba8132f1399f23cb3dd7" dependencies = [ "prost", ] @@ -2541,13 +2541,13 @@ dependencies = [ [[package]] name = "sentry_protos" -version = "0.4.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eae1eac4a748b11a2bb5b342bea8546085751cf9a45e30fb1276b072bb5541e6" +version = "0.8.4" +source = "git+https://github.com/getsentry/sentry-protos?branch=george%2Fpush-taskbroker%2Fcreate-worker-service#f4cd3043b043c2f42e069104c3704177e1696504" dependencies = [ "prost", "prost-types", "tonic", + "tonic-prost", ] [[package]] @@ -2688,16 +2688,6 @@ dependencies = [ "serde", ] -[[package]] -name = "socket2" -version = "0.5.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e22376abed350d73dd1cd119b57ffccad95b4e585a7cda43e286245ce23c0678" -dependencies = [ - "libc", - "windows-sys 0.52.0", -] - [[package]] name = "socket2" version = "0.6.0" @@ -3166,7 +3156,7 @@ dependencies = [ "pin-project-lite", "signal-hook-registry", "slab", - "socket2 0.6.0", + "socket2", "tokio-macros", "windows-sys 0.59.0", ] @@ -3236,9 +3226,9 @@ dependencies = [ [[package]] name = "tonic" -version = "0.13.1" +version = "0.14.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7e581ba15a835f4d9ea06c55ab1bd4dce26fc53752c69a04aac00703bfb49ba9" +checksum = "fec7c61a0695dc1887c1b53952990f3ad2e3a31453e1f49f10e75424943a93ec" dependencies = [ "async-trait", "axum", @@ -3253,8 +3243,8 @@ dependencies = [ "hyper-util", "percent-encoding", "pin-project", - "prost", - "socket2 0.5.10", + "socket2", + "sync_wrapper", "tokio", "tokio-stream", "tower", @@ -3265,14 +3255,26 @@ dependencies = [ [[package]] name = "tonic-health" -version = "0.13.1" +version = "0.14.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cb87334d340313fefa513b6e60794d44a86d5f039b523229c99c323e4e19ca4b" +checksum = "f4ff0636fef47afb3ec02818f5bceb4377b8abb9d6a386aeade18bd6212f8eb7" dependencies = [ "prost", "tokio", "tokio-stream", "tonic", + "tonic-prost", +] + +[[package]] +name = "tonic-prost" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a55376a0bbaa4975a3f10d009ad763d8f4108f067c7c2e74f3001fb49778d309" +dependencies = [ + "bytes", + "prost", + "tonic", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 67119a08..6ed2b7b9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,8 +26,8 @@ http-body-util = "0.1.2" libsqlite3-sys = "0.30.1" metrics = "0.24.0" metrics-exporter-statsd = "0.9.0" -prost = "0.13" -prost-types = "0.13.3" +prost = "0.14" +prost-types = "0.14" rand = "0.8.5" rdkafka = { version = "0.37.0", features = ["cmake-build", "ssl"] } sentry = { version = "0.41.0", default-features = false, features = [ @@ -41,7 +41,7 @@ sentry = { version = "0.41.0", default-features = false, features = [ "tracing", "logs" ] } -sentry_protos = "0.4.11" +sentry_protos = { git = "https://github.com/getsentry/sentry-protos", branch = "george/push-taskbroker/create-worker-service" } serde = "1.0.214" serde_yaml = "0.9.34" sha2 = "0.10.8" @@ -49,8 +49,8 @@ sqlx = { version = "0.8.3", features = ["sqlite", "runtime-tokio", "chrono", "po tokio = { version = "1.43.1", features = ["full"] } tokio-stream = { version = "0.1.16", features = ["full"] } tokio-util = "0.7.12" -tonic = "0.13" -tonic-health = "0.13" +tonic = "0.14" +tonic-health = "0.14" tower = "0.5.1" tracing = "0.1.40" tracing-subscriber = { version = "0.3.18", features = [ diff --git a/README.md b/README.md index 860f446b..7a4f79bc 100644 --- a/README.md +++ b/README.md @@ -77,7 +77,7 @@ The test suite is composed of unit and integration tests in Rust, and end-to-end ```bash # Run unit/integration tests -make test +make unit-test # Run end-to-end tests make integration-test diff --git a/src/config.rs b/src/config.rs index 67b571d2..3f3bddb5 100644 --- a/src/config.rs +++ b/src/config.rs @@ -239,6 +239,21 @@ pub struct Config { /// Enable additional metrics for the sqlite. pub enable_sqlite_status_metrics: bool, + + /// Run the taskbroker in push mode (as opposed to pull mode). + pub push_mode: bool, + + /// The number of concurrent dispatchers to run. + pub dispatchers: usize, + + /// The number of concurrent pushers each dispatcher should run. + pub pushers: usize, + + /// The size of the push queue. + pub push_queue_size: usize, + + /// The worker service endpoint. + pub worker_endpoint: String, } impl Default for Config { @@ -308,6 +323,11 @@ impl Default for Config { full_vacuum_on_upkeep: true, vacuum_interval_ms: 30000, enable_sqlite_status_metrics: true, + push_mode: false, + dispatchers: 1, + pushers: 1, + push_queue_size: 1, + worker_endpoint: "http://127.0.0.1:50052".into(), } } } diff --git a/src/dispatch.rs b/src/dispatch.rs new file mode 100644 index 00000000..97f1bcd8 --- /dev/null +++ b/src/dispatch.rs @@ -0,0 +1,190 @@ +use std::sync::Arc; +use std::time::{Duration, Instant}; + +use sentry_protos::taskbroker::v1::worker_service_client::WorkerServiceClient; +use sentry_protos::taskbroker::v1::{PushTaskRequest, TaskActivation}; + +use anyhow::Result; +use elegant_departure::get_shutdown_guard; +use prost::Message; +use tokio::sync::mpsc::{self, Receiver, Sender}; +use tokio::time::sleep; +use tonic::transport::Channel; +use tracing::{debug, error, info}; + +use crate::config::Config; +use crate::store::inflight_activation::{InflightActivation, InflightActivationStore}; + +/// This data structure fetches pending activations from the store and pushes them to the worker service. Each dispatcher has... +/// - One "fetch" loop that gets a pending activation from the store, sends it to a push channel, and repeats +/// - One or more "push" loops, each of which receives an activation from a channel, pushes that activation to a worker, and repeats +pub struct TaskDispatcher { + /// Sender for every push loop. + senders: Vec>, + + /// Receiver for every push loop. + receivers: Vec>, + + /// For every pending activation, increment and send to the channel with this index. + next_sender_idx: usize, + + /// Broker configuration. + config: Arc, + + /// Broker inflight activation store. + store: Arc, +} + +impl TaskDispatcher { + /// Create a new task dispatcher. + pub fn new(config: Arc, store: Arc) -> Self { + let n = config.pushers; + + let mut senders = Vec::with_capacity(n); + let mut receivers = Vec::with_capacity(n); + let next_sender_idx = 0; + + for _ in 0..n { + let (tx, rx) = mpsc::channel(config.push_queue_size); + senders.push(tx); + receivers.push(rx); + } + + Self { + senders, + receivers, + next_sender_idx, + config, + store, + } + } + + /// Initialize push loops and dispatcher loop. + pub async fn start(mut self) -> Result<()> { + let n = self.senders.len(); + info!("Starting {n} push loops..."); + + let endpoint = self.config.worker_endpoint.clone(); + let receivers = std::mem::take(&mut self.receivers); + + // Initialize each push loop + for mut rx in receivers.into_iter() { + let endpoint = endpoint.clone(); + + tokio::spawn(async move { + let mut worker = match WorkerServiceClient::connect(endpoint).await { + Ok(w) => w, + + Err(e) => { + error!("Failed to connect to worker - {:?}", e); + return; + } + }; + + while let Some(activation) = rx.recv().await { + // Receive activation from the channel + let id = activation.id.clone(); + + // Try to push activation to the worker service + if let Err(e) = push_task(&mut worker, activation).await { + error!("Pushing activation {id} resulted in error - {:?}", e); + } else { + debug!("Activation {id} was sent to worker!"); + } + } + }); + } + + info!("Starting fetch loop..."); + let guard = get_shutdown_guard().shutdown_on_drop(); + + // Initialize the fetch loop + loop { + tokio::select! { + _ = guard.wait() => { + info!("Fetch loop received shutdown signal"); + break; + } + + _ = async { + debug!("About to fetch next activation..."); + self.fetch_activation().await; + } => {} + } + } + + info!("Activation dispatcher shutting down..."); + Ok(()) + } + + /// Grab the next pending activation from the store, mark it as processing, and send to push channel. + async fn fetch_activation(&mut self) { + let start = Instant::now(); + metrics::counter!("pusher.fetch_activation.runs").increment(1); + + debug!("Fetching next pending activation..."); + + match self.store.get_pending_activation(None, None).await { + Ok(Some(activation)) => { + let id = activation.id.clone(); + + let idx = self.next_sender_idx % self.senders.len(); + self.next_sender_idx = self.next_sender_idx.wrapping_add(1); + + if let Err(e) = self.senders[idx].send(activation).await { + error!("Failed to send activation {id} to worker - {:?}", e); + } + + metrics::histogram!("pusher.fetch_activation.duration").record(start.elapsed()); + } + + Ok(_) => { + debug!("No pending activations, sleeping briefly..."); + sleep(milliseconds(100)).await; + + metrics::histogram!("pusher.fetch_activation.duration").record(start.elapsed()); + } + + Err(e) => { + error!("Failed to fetch pending activations - {:?}", e); + sleep(milliseconds(100)).await; + + metrics::histogram!("pusher.fetch_activation.duration").record(start.elapsed()); + } + } + } +} + +/// Decode task activation and push it to a worker. +async fn push_task( + worker: &mut WorkerServiceClient, + activation: InflightActivation, +) -> Result<()> { + let start = Instant::now(); + let id = activation.id.clone(); + + // Try to decode activation (if it fails, we will see the error where `push_task` is called) + let task = TaskActivation::decode(&activation.activation as &[u8])?; + + let request = PushTaskRequest { task: Some(task) }; + + let result = match worker.push_task(request).await { + Ok(_) => { + debug!("Successfully sent activation {id} to worker service!"); + Ok(()) + } + + Err(e) => { + error!("Could not push activation {id} to worker service - {:?}", e); + Err(e.into()) + } + }; + + metrics::histogram!("pusher.push_task.duration").record(start.elapsed()); + result +} + +#[inline] +fn milliseconds(i: u64) -> Duration { + Duration::from_millis(i) +} diff --git a/src/grpc/server.rs b/src/grpc/server.rs index f5ac9292..68d06699 100644 --- a/src/grpc/server.rs +++ b/src/grpc/server.rs @@ -9,11 +9,13 @@ use std::sync::Arc; use std::time::Instant; use tonic::{Request, Response, Status}; +use crate::config::Config; use crate::store::inflight_activation::{InflightActivationStatus, InflightActivationStore}; use tracing::{error, instrument}; pub struct TaskbrokerServer { pub store: Arc, + pub config: Arc, } #[tonic::async_trait] @@ -23,6 +25,12 @@ impl ConsumerService for TaskbrokerServer { &self, request: Request, ) -> Result, Status> { + if self.config.push_mode { + return Err(Status::permission_denied( + "Cannot call while broker is in PUSH mode", + )); + } + let start_time = Instant::now(); let application = &request.get_ref().application; let namespace = &request.get_ref().namespace; diff --git a/src/grpc/server_tests.rs b/src/grpc/server_tests.rs index b99911a2..29e36f96 100644 --- a/src/grpc/server_tests.rs +++ b/src/grpc/server_tests.rs @@ -7,7 +7,7 @@ use sentry_protos::taskbroker::v1::{ }; use tonic::{Code, Request}; -use crate::test_utils::{create_test_store, make_activations}; +use crate::test_utils::{create_config, create_test_store, make_activations}; #[tokio::test] #[rstest] @@ -15,7 +15,9 @@ use crate::test_utils::{create_test_store, make_activations}; #[case::postgres("postgres")] async fn test_get_task(#[case] adapter: &str) { let store = create_test_store(adapter).await; - let service = TaskbrokerServer { store }; + let config = create_config(); + + let service = TaskbrokerServer { store, config }; let request = GetTaskRequest { namespace: None, application: None, @@ -34,7 +36,9 @@ async fn test_get_task(#[case] adapter: &str) { #[allow(deprecated)] async fn test_set_task_status(#[case] adapter: &str) { let store = create_test_store(adapter).await; - let service = TaskbrokerServer { store }; + let config = create_config(); + + let service = TaskbrokerServer { store, config }; let request = SetTaskStatusRequest { id: "test_task".to_string(), status: 5, // Complete @@ -53,7 +57,9 @@ async fn test_set_task_status(#[case] adapter: &str) { #[allow(deprecated)] async fn test_set_task_status_invalid(#[case] adapter: &str) { let store = create_test_store(adapter).await; - let service = TaskbrokerServer { store }; + let config = create_config(); + + let service = TaskbrokerServer { store, config }; let request = SetTaskStatusRequest { id: "test_task".to_string(), status: 1, // Invalid @@ -76,10 +82,12 @@ async fn test_set_task_status_invalid(#[case] adapter: &str) { #[allow(deprecated)] async fn test_get_task_success(#[case] adapter: &str) { let store = create_test_store(adapter).await; + let config = create_config(); + let activations = make_activations(1); store.store(activations).await.unwrap(); - let service = TaskbrokerServer { store }; + let service = TaskbrokerServer { store, config }; let request = GetTaskRequest { namespace: None, application: None, @@ -99,6 +107,8 @@ async fn test_get_task_success(#[case] adapter: &str) { #[allow(deprecated)] async fn test_get_task_with_application_success(#[case] adapter: &str) { let store = create_test_store(adapter).await; + let config = create_config(); + let mut activations = make_activations(2); let mut payload = TaskActivation::decode(&activations[1].activation as &[u8]).unwrap(); @@ -108,7 +118,7 @@ async fn test_get_task_with_application_success(#[case] adapter: &str) { store.store(activations).await.unwrap(); - let service = TaskbrokerServer { store }; + let service = TaskbrokerServer { store, config }; let request = GetTaskRequest { namespace: None, application: Some("hammers".into()), @@ -129,12 +139,14 @@ async fn test_get_task_with_application_success(#[case] adapter: &str) { #[allow(deprecated)] async fn test_get_task_with_namespace_requires_application(#[case] adapter: &str) { let store = create_test_store(adapter).await; + let config = create_config(); + let activations = make_activations(2); let namespace = activations[0].namespace.clone(); store.store(activations).await.unwrap(); - let service = TaskbrokerServer { store }; + let service = TaskbrokerServer { store, config }; let request = GetTaskRequest { namespace: Some(namespace), application: None, @@ -153,10 +165,12 @@ async fn test_get_task_with_namespace_requires_application(#[case] adapter: &str #[allow(deprecated)] async fn test_set_task_status_success(#[case] adapter: &str) { let store = create_test_store(adapter).await; + let config = create_config(); + let activations = make_activations(2); store.store(activations).await.unwrap(); - let service = TaskbrokerServer { store }; + let service = TaskbrokerServer { store, config }; let request = GetTaskRequest { namespace: None, @@ -192,6 +206,8 @@ async fn test_set_task_status_success(#[case] adapter: &str) { #[allow(deprecated)] async fn test_set_task_status_with_application(#[case] adapter: &str) { let store = create_test_store(adapter).await; + let config = create_config(); + let mut activations = make_activations(2); let mut payload = TaskActivation::decode(&activations[1].activation as &[u8]).unwrap(); @@ -201,7 +217,7 @@ async fn test_set_task_status_with_application(#[case] adapter: &str) { store.store(activations).await.unwrap(); - let service = TaskbrokerServer { store }; + let service = TaskbrokerServer { store, config }; let request = SetTaskStatusRequest { id: "id_0".to_string(), status: 5, // Complete @@ -229,6 +245,8 @@ async fn test_set_task_status_with_application(#[case] adapter: &str) { #[allow(deprecated)] async fn test_set_task_status_with_application_no_match(#[case] adapter: &str) { let store = create_test_store(adapter).await; + let config = create_config(); + let mut activations = make_activations(2); let mut payload = TaskActivation::decode(&activations[1].activation as &[u8]).unwrap(); @@ -238,7 +256,7 @@ async fn test_set_task_status_with_application_no_match(#[case] adapter: &str) { store.store(activations).await.unwrap(); - let service = TaskbrokerServer { store }; + let service = TaskbrokerServer { store, config }; // Request a task from an application without any activations. let request = SetTaskStatusRequest { id: "id_0".to_string(), @@ -261,12 +279,14 @@ async fn test_set_task_status_with_application_no_match(#[case] adapter: &str) { #[allow(deprecated)] async fn test_set_task_status_with_namespace_requires_application(#[case] adapter: &str) { let store = create_test_store(adapter).await; + let config = create_config(); + let activations = make_activations(2); let namespace = activations[0].namespace.clone(); store.store(activations).await.unwrap(); - let service = TaskbrokerServer { store }; + let service = TaskbrokerServer { store, config }; let request = SetTaskStatusRequest { id: "id_0".to_string(), status: 5, // Complete diff --git a/src/lib.rs b/src/lib.rs index 33567944..6ff2b08d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,6 +2,7 @@ use clap::Parser; use std::fs; pub mod config; +pub mod dispatch; pub mod grpc; pub mod kafka; pub mod logging; diff --git a/src/main.rs b/src/main.rs index 7970939d..909fb594 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,6 +2,7 @@ use anyhow::{Error, anyhow}; use chrono::Utc; use clap::Parser; use std::{sync::Arc, time::Duration}; +use taskbroker::dispatch::TaskDispatcher; use taskbroker::kafka::inflight_activation_batcher::{ ActivationBatcherConfig, InflightActivationBatcher, }; @@ -39,16 +40,16 @@ use taskbroker::store::postgres_activation_store::{ use taskbroker::{Args, get_version}; use tonic_health::ServingStatus; -async fn log_task_completion(name: &str, task: JoinHandle>) { +async fn log_task_completion>(name: T, task: JoinHandle>) { match task.await { Ok(Ok(())) => { - info!("Task {} completed", name); + info!("Task {} completed", name.as_ref()); } Ok(Err(e)) => { - error!("Task {} failed: {:?}", name, e); + error!("Task {} failed: {:?}", name.as_ref(), e); } Err(e) => { - error!("Task {} panicked: {:?}", name, e); + error!("Task {} panicked: {:?}", name.as_ref(), e); } } } @@ -190,22 +191,24 @@ async fn main() -> Result<(), Error> { // GRPC server let grpc_server_task = tokio::spawn({ - let grpc_store = store.clone(); - let grpc_config = config.clone(); + let store = store.clone(); + let config = config.clone(); + async move { - let addr = format!("{}:{}", grpc_config.grpc_addr, grpc_config.grpc_port) + let addr = format!("{}:{}", config.grpc_addr, config.grpc_port) .parse() .expect("Failed to parse address"); let layers = tower::ServiceBuilder::new() .layer(MetricsLayer::default()) - .layer(AuthLayer::new(&grpc_config)) + .layer(AuthLayer::new(&config)) .into_inner(); let server = Server::builder() .layer(layers) .add_service(ConsumerServiceServer::new(TaskbrokerServer { - store: grpc_store, + store, + config, })) .add_service(health_service.clone()) .serve(addr); @@ -236,7 +239,27 @@ async fn main() -> Result<(), Error> { } }); - elegant_departure::tokio::depart() + // Activation dispatchers + let dispatchers = if config.push_mode { + info!("Running in PUSH mode"); + + (0..config.dispatchers) + .map(|_| { + let store = store.clone(); + let config = config.clone(); + + tokio::spawn(async move { + let dispatcher = TaskDispatcher::new(config, store); + dispatcher.start().await + }) + }) + .collect() + } else { + info!("Running in PULL mode"); + vec![] + }; + + let mut departure = elegant_departure::tokio::depart() .on_termination() .on_sigint() .on_signal(SignalKind::hangup()) @@ -244,8 +267,13 @@ async fn main() -> Result<(), Error> { .on_completion(log_task_completion("consumer", consumer_task)) .on_completion(log_task_completion("grpc_server", grpc_server_task)) .on_completion(log_task_completion("upkeep_task", upkeep_task)) - .on_completion(log_task_completion("maintenance_task", maintenance_task)) - .await; + .on_completion(log_task_completion("maintenance_task", maintenance_task)); + + // Register each activation dispatch task + for (i, handle) in dispatchers.into_iter().enumerate() { + let task_name = format!("activation_dispatcher_{}", i); + departure = departure.on_completion(log_task_completion(task_name, handle)); + } Ok(()) } From a7286132f1ca96fd1196881d2ca11ca508a5d94a Mon Sep 17 00:00:00 2001 From: james-mcnulty Date: Mon, 16 Mar 2026 17:20:25 -0700 Subject: [PATCH 2/8] Add Unit Tests, Flush Tasks on Shutdown --- Cargo.lock | 2 +- Cargo.toml | 2 +- src/dispatch.rs | 275 +++++++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 275 insertions(+), 4 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index d2a2ac9a..6e56936d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2542,7 +2542,7 @@ dependencies = [ [[package]] name = "sentry_protos" version = "0.8.4" -source = "git+https://github.com/getsentry/sentry-protos?branch=george%2Fpush-taskbroker%2Fcreate-worker-service#f4cd3043b043c2f42e069104c3704177e1696504" +source = "git+https://github.com/getsentry/sentry-protos#7873851032c697925dd7e532b6ad9888911f93b8" dependencies = [ "prost", "prost-types", diff --git a/Cargo.toml b/Cargo.toml index 6ed2b7b9..1e82bc3a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -41,7 +41,7 @@ sentry = { version = "0.41.0", default-features = false, features = [ "tracing", "logs" ] } -sentry_protos = { git = "https://github.com/getsentry/sentry-protos", branch = "george/push-taskbroker/create-worker-service" } +sentry_protos = { git = "https://github.com/getsentry/sentry-protos" } serde = "1.0.214" serde_yaml = "0.9.34" sha2 = "0.10.8" diff --git a/src/dispatch.rs b/src/dispatch.rs index 97f1bcd8..4ed5dcf9 100644 --- a/src/dispatch.rs +++ b/src/dispatch.rs @@ -8,6 +8,7 @@ use anyhow::Result; use elegant_departure::get_shutdown_guard; use prost::Message; use tokio::sync::mpsc::{self, Receiver, Sender}; +use tokio::task::JoinHandle; use tokio::time::sleep; use tonic::transport::Channel; use tracing::{debug, error, info}; @@ -59,6 +60,18 @@ impl TaskDispatcher { } } + /// Number of senders (and receivers) for testing purposes. + #[cfg(test)] + pub fn pusher_count(&self) -> usize { + self.senders.len() + } + + /// Take the receivers so a test can drain them. + #[cfg(test)] + pub fn take_receivers(&mut self) -> Vec> { + std::mem::take(&mut self.receivers) + } + /// Initialize push loops and dispatcher loop. pub async fn start(mut self) -> Result<()> { let n = self.senders.len(); @@ -67,11 +80,14 @@ impl TaskDispatcher { let endpoint = self.config.worker_endpoint.clone(); let receivers = std::mem::take(&mut self.receivers); + // Collect pusher handles so we can wait on them if shutdown is initiated + let mut handles: Vec> = Vec::with_capacity(receivers.len()); + // Initialize each push loop for mut rx in receivers.into_iter() { let endpoint = endpoint.clone(); - tokio::spawn(async move { + let handle = tokio::spawn(async move { let mut worker = match WorkerServiceClient::connect(endpoint).await { Ok(w) => w, @@ -93,6 +109,8 @@ impl TaskDispatcher { } } }); + + handles.push(handle); } info!("Starting fetch loop..."); @@ -114,11 +132,19 @@ impl TaskDispatcher { } info!("Activation dispatcher shutting down..."); + + // Close channels and drain any tasks still in the pushing pipeline + drop(std::mem::take(&mut self.senders)); + for handle in handles { + let _ = handle.await; + } + + info!("Activation dispatcher shut down."); Ok(()) } /// Grab the next pending activation from the store, mark it as processing, and send to push channel. - async fn fetch_activation(&mut self) { + pub async fn fetch_activation(&mut self) { let start = Instant::now(); metrics::counter!("pusher.fetch_activation.runs").increment(1); @@ -188,3 +214,248 @@ async fn push_task( fn milliseconds(i: u64) -> Duration { Duration::from_millis(i) } + +#[cfg(test)] +mod tests { + use std::sync::{Arc, Mutex}; + + use crate::config::Config; + use crate::store::inflight_activation::{ + FailedTasksForwarder, InflightActivation, InflightActivationStatus, + InflightActivationStore, QueryResult, + }; + use crate::test_utils::{create_test_store, make_activations}; + + use anyhow::Error; + use async_trait::async_trait; + use chrono::{DateTime, Utc}; + + use super::TaskDispatcher; + + /// Mock store that returns activations from a queue for `get_pending_activation`. + struct MockStore { + activations: Mutex>, + } + + impl MockStore { + fn new(activations: Vec) -> Arc { + Arc::new(Self { + activations: Mutex::new(activations), + }) + } + } + + #[async_trait] + impl InflightActivationStore for MockStore { + async fn vacuum_db(&self) -> Result<(), Error> { + unimplemented!() + } + + async fn full_vacuum_db(&self) -> Result<(), Error> { + unimplemented!() + } + + async fn db_size(&self) -> Result { + unimplemented!() + } + + async fn get_by_id(&self, _id: &str) -> Result, Error> { + unimplemented!() + } + + async fn store(&self, _batch: Vec) -> Result { + unimplemented!() + } + + async fn get_pending_activations_from_namespaces( + &self, + _application: Option<&str>, + _namespaces: Option<&[String]>, + limit: Option, + ) -> Result, Error> { + let limit = limit.unwrap_or(1) as usize; + let mut list = self.activations.lock().unwrap(); + let n = limit.min(list.len()); + + if n == 0 { + return Ok(vec![]); + } + + Ok(list.drain(..n).collect()) + } + + async fn pending_activation_max_lag(&self, _now: &DateTime) -> f64 { + unimplemented!() + } + + async fn count_by_status(&self, _status: InflightActivationStatus) -> Result { + Ok(self.activations.lock().unwrap().len()) + } + + async fn count(&self) -> Result { + Ok(self.activations.lock().unwrap().len()) + } + + async fn set_status( + &self, + _id: &str, + _status: InflightActivationStatus, + ) -> Result, Error> { + unimplemented!() + } + + async fn set_processing_deadline( + &self, + _id: &str, + _deadline: Option>, + ) -> Result<(), Error> { + unimplemented!() + } + + async fn delete_activation(&self, _id: &str) -> Result<(), Error> { + unimplemented!() + } + + async fn get_retry_activations(&self) -> Result, Error> { + unimplemented!() + } + + async fn clear(&self) -> Result<(), Error> { + unimplemented!() + } + + async fn handle_processing_deadline(&self) -> Result { + unimplemented!() + } + + async fn handle_processing_attempts(&self) -> Result { + unimplemented!() + } + + async fn handle_expires_at(&self) -> Result { + unimplemented!() + } + + async fn handle_delay_until(&self) -> Result { + unimplemented!() + } + + async fn handle_failed_tasks(&self) -> Result { + unimplemented!() + } + + async fn mark_completed(&self, _ids: Vec) -> Result { + unimplemented!() + } + + async fn remove_completed(&self) -> Result { + unimplemented!() + } + + async fn remove_killswitched( + &self, + _killswitched_tasks: Vec, + ) -> Result { + unimplemented!() + } + } + + /// Asserts that a dispatcher built with X pushers has exactly X senders (and thus X receivers). + #[test] + fn pushers_x_creates_x_senders_and_receivers() { + // Use an empty mock store because we only care about construction, not fetching + let store: Arc = MockStore::new(vec![]); + + let config = Arc::new(Config { + pushers: 5, + push_queue_size: 10, + ..Config::default() + }); + + let dispatcher = TaskDispatcher::new(config, store); + + // One sender (and one receiver) per pusher + assert_eq!(dispatcher.pusher_count(), 5); + } + + /// Asserts that the fetch loop distributes activations round-robin across channels (0, 1, 2, 0, 1, 2, ...) + #[tokio::test] + async fn round_robin_sends_to_channels_0_1_2_0_1_2() { + // Six activations (id_0 .. id_5) so we get two full cycles across three channels + let activations = make_activations(6); + let store = MockStore::new(activations); + + let config = Arc::new(Config { + pushers: 3, + push_queue_size: 10, + ..Config::default() + }); + + let mut dispatcher = TaskDispatcher::new(config, store); + + // Take receivers so we can drain them - dispatcher keeps senders and will push to them + let mut receivers = dispatcher.take_receivers(); + assert_eq!(receivers.len(), 3); + + // Run the fetch loop six times - each run takes one activation from the mock and sends to next channel + for _ in 0..6 { + dispatcher.fetch_activation().await; + } + + // Receive in the same order the dispatcher sends - channel 0, then 1, then 2, then 0, 1, 2 + let mut received_by_channel: Vec> = vec![vec![], vec![], vec![]]; + for i in 0..6 { + let idx = i % 3; + let activation = receivers[idx].recv().await.expect("activation"); + received_by_channel[idx].push(activation.id.clone()); + } + + // Make sure round-robin works as intended... + // - Activations 1 and 4 go to channel 0 + // - Activations 2 and 5 go to channel 1 + // - Activations 3 and 6 go to channel 2 + assert_eq!(received_by_channel[0], &["id_0", "id_3"]); + assert_eq!(received_by_channel[1], &["id_1", "id_4"]); + assert_eq!(received_by_channel[2], &["id_2", "id_5"]); + } + + /// Asserts that after N fetch steps the store has zero pending activations (each fetch marks one as processing). + #[tokio::test] + async fn fetch_loop_drains_store() { + let activations = make_activations(3); + let store = create_test_store("sqlite").await; + + // Add activations to test store + store.store(activations).await.unwrap(); + assert_eq!(store.count_pending_activations().await.unwrap(), 3); + + let config = Arc::new(Config { + pushers: 2, + push_queue_size: 10, + ..Config::default() + }); + + let mut dispatcher = TaskDispatcher::new(config, store.clone()); + let mut receivers = dispatcher.take_receivers(); + + // Run fetch three times - each call gets one pending activation and moves it to processing + for _ in 0..3 { + dispatcher.fetch_activation().await; + } + + // Drain all activations from the channels so we've fully consumed what was fetched + let mut received = 0; + + for mut rx in receivers.drain(..) { + while rx.try_recv().is_ok() { + received += 1; + } + } + + // Have all activations been received? + assert_eq!(received, 3); + + // Real store marks as processing on `get_pending_activation` - so no pending left + assert_eq!(store.count_pending_activations().await.unwrap(), 0); + } +} From 6866ef20cb8d53874426f020c08b16b0a8b7e451 Mon Sep 17 00:00:00 2001 From: james-mcnulty Date: Tue, 17 Mar 2026 09:08:26 -0700 Subject: [PATCH 3/8] Switch to Sentry Protos Release --- Cargo.lock | 5 +++-- Cargo.toml | 2 +- src/dispatch.rs | 11 +++++++++-- 3 files changed, 13 insertions(+), 5 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 6e56936d..cf633443 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2541,8 +2541,9 @@ dependencies = [ [[package]] name = "sentry_protos" -version = "0.8.4" -source = "git+https://github.com/getsentry/sentry-protos#7873851032c697925dd7e532b6ad9888911f93b8" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d3c4e8bca4c556eec616dc2594e519248891ca84f8bf958016c2c416223d8ff" dependencies = [ "prost", "prost-types", diff --git a/Cargo.toml b/Cargo.toml index 1e82bc3a..2d41f501 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -41,7 +41,7 @@ sentry = { version = "0.41.0", default-features = false, features = [ "tracing", "logs" ] } -sentry_protos = { git = "https://github.com/getsentry/sentry-protos" } +sentry_protos = "0.8.5" serde = "1.0.214" serde_yaml = "0.9.34" sha2 = "0.10.8" diff --git a/src/dispatch.rs b/src/dispatch.rs index 4ed5dcf9..79258e4f 100644 --- a/src/dispatch.rs +++ b/src/dispatch.rs @@ -78,6 +78,8 @@ impl TaskDispatcher { info!("Starting {n} push loops..."); let endpoint = self.config.worker_endpoint.clone(); + let callback_url = format!("{}:{}", self.config.grpc_addr, self.config.grpc_port); + let receivers = std::mem::take(&mut self.receivers); // Collect pusher handles so we can wait on them if shutdown is initiated @@ -86,6 +88,7 @@ impl TaskDispatcher { // Initialize each push loop for mut rx in receivers.into_iter() { let endpoint = endpoint.clone(); + let callback_url = callback_url.clone(); let handle = tokio::spawn(async move { let mut worker = match WorkerServiceClient::connect(endpoint).await { @@ -102,7 +105,7 @@ impl TaskDispatcher { let id = activation.id.clone(); // Try to push activation to the worker service - if let Err(e) = push_task(&mut worker, activation).await { + if let Err(e) = push_task(&mut worker, activation, callback_url.clone()).await { error!("Pushing activation {id} resulted in error - {:?}", e); } else { debug!("Activation {id} was sent to worker!"); @@ -185,6 +188,7 @@ impl TaskDispatcher { async fn push_task( worker: &mut WorkerServiceClient, activation: InflightActivation, + callback_url: String, ) -> Result<()> { let start = Instant::now(); let id = activation.id.clone(); @@ -192,7 +196,10 @@ async fn push_task( // Try to decode activation (if it fails, we will see the error where `push_task` is called) let task = TaskActivation::decode(&activation.activation as &[u8])?; - let request = PushTaskRequest { task: Some(task) }; + let request = PushTaskRequest { + task: Some(task), + callback_url, + }; let result = match worker.push_task(request).await { Ok(_) => { From 0a53d58971d7b666336069b3afbaa7e7f792f577 Mon Sep 17 00:00:00 2001 From: james-mcnulty Date: Tue, 17 Mar 2026 17:33:02 -0700 Subject: [PATCH 4/8] Replace Dispatcher w/Separate Fetch and Push Pools --- Cargo.lock | 20 ++- Cargo.toml | 1 + src/config.rs | 8 +- src/dispatch.rs | 468 ------------------------------------------------ src/fetch.rs | 111 ++++++++++++ src/lib.rs | 3 +- src/main.rs | 46 ++--- src/push.rs | 146 +++++++++++++++ 8 files changed, 307 insertions(+), 496 deletions(-) delete mode 100644 src/dispatch.rs create mode 100644 src/fetch.rs create mode 100644 src/push.rs diff --git a/Cargo.lock b/Cargo.lock index cf633443..9d26de75 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -741,6 +741,9 @@ name = "fastrand" version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" +dependencies = [ + "getrandom 0.2.16", +] [[package]] name = "figment" @@ -781,6 +784,18 @@ dependencies = [ "spin", ] +[[package]] +name = "flume" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e139bc46ca777eb5efaf62df0ab8cc5fd400866427e56c68b22e414e53bd3be" +dependencies = [ + "fastrand", + "futures-core", + "futures-sink", + "spin", +] + [[package]] name = "fnv" version = "1.0.7" @@ -940,8 +955,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "335ff9f135e4384c8150d6f27c6daed433577f86b4750418338c01a1a2528592" dependencies = [ "cfg-if", + "js-sys", "libc", "wasi 0.11.1+wasi-snapshot-preview1", + "wasm-bindgen", ] [[package]] @@ -2893,7 +2910,7 @@ checksum = "c2d12fe70b2c1b4401038055f90f151b78208de1f9f89a7dbfd41587a10c3eea" dependencies = [ "atoi", "chrono", - "flume", + "flume 0.11.1", "futures-channel", "futures-core", "futures-executor", @@ -2983,6 +3000,7 @@ dependencies = [ "derive_builder", "elegant-departure", "figment", + "flume 0.12.0", "futures", "futures-util", "hex", diff --git a/Cargo.toml b/Cargo.toml index 2d41f501..a9758e93 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,6 +17,7 @@ clap = { version = "4.5.20", features = ["derive"] } derive_builder = "0.20.2" elegant-departure = { version = "0.3.1", features = ["tokio"] } figment = { version = "0.10.19", features = ["env", "yaml", "test"] } +flume = "0.12.0" futures = "0.3.31" futures-util = "0.3.31" hex = "0.4.3" diff --git a/src/config.rs b/src/config.rs index 3f3bddb5..e9c03dfe 100644 --- a/src/config.rs +++ b/src/config.rs @@ -244,10 +244,10 @@ pub struct Config { pub push_mode: bool, /// The number of concurrent dispatchers to run. - pub dispatchers: usize, + pub fetch_threads: usize, /// The number of concurrent pushers each dispatcher should run. - pub pushers: usize, + pub push_threads: usize, /// The size of the push queue. pub push_queue_size: usize, @@ -324,8 +324,8 @@ impl Default for Config { vacuum_interval_ms: 30000, enable_sqlite_status_metrics: true, push_mode: false, - dispatchers: 1, - pushers: 1, + fetch_threads: 1, + push_threads: 1, push_queue_size: 1, worker_endpoint: "http://127.0.0.1:50052".into(), } diff --git a/src/dispatch.rs b/src/dispatch.rs deleted file mode 100644 index 79258e4f..00000000 --- a/src/dispatch.rs +++ /dev/null @@ -1,468 +0,0 @@ -use std::sync::Arc; -use std::time::{Duration, Instant}; - -use sentry_protos::taskbroker::v1::worker_service_client::WorkerServiceClient; -use sentry_protos::taskbroker::v1::{PushTaskRequest, TaskActivation}; - -use anyhow::Result; -use elegant_departure::get_shutdown_guard; -use prost::Message; -use tokio::sync::mpsc::{self, Receiver, Sender}; -use tokio::task::JoinHandle; -use tokio::time::sleep; -use tonic::transport::Channel; -use tracing::{debug, error, info}; - -use crate::config::Config; -use crate::store::inflight_activation::{InflightActivation, InflightActivationStore}; - -/// This data structure fetches pending activations from the store and pushes them to the worker service. Each dispatcher has... -/// - One "fetch" loop that gets a pending activation from the store, sends it to a push channel, and repeats -/// - One or more "push" loops, each of which receives an activation from a channel, pushes that activation to a worker, and repeats -pub struct TaskDispatcher { - /// Sender for every push loop. - senders: Vec>, - - /// Receiver for every push loop. - receivers: Vec>, - - /// For every pending activation, increment and send to the channel with this index. - next_sender_idx: usize, - - /// Broker configuration. - config: Arc, - - /// Broker inflight activation store. - store: Arc, -} - -impl TaskDispatcher { - /// Create a new task dispatcher. - pub fn new(config: Arc, store: Arc) -> Self { - let n = config.pushers; - - let mut senders = Vec::with_capacity(n); - let mut receivers = Vec::with_capacity(n); - let next_sender_idx = 0; - - for _ in 0..n { - let (tx, rx) = mpsc::channel(config.push_queue_size); - senders.push(tx); - receivers.push(rx); - } - - Self { - senders, - receivers, - next_sender_idx, - config, - store, - } - } - - /// Number of senders (and receivers) for testing purposes. - #[cfg(test)] - pub fn pusher_count(&self) -> usize { - self.senders.len() - } - - /// Take the receivers so a test can drain them. - #[cfg(test)] - pub fn take_receivers(&mut self) -> Vec> { - std::mem::take(&mut self.receivers) - } - - /// Initialize push loops and dispatcher loop. - pub async fn start(mut self) -> Result<()> { - let n = self.senders.len(); - info!("Starting {n} push loops..."); - - let endpoint = self.config.worker_endpoint.clone(); - let callback_url = format!("{}:{}", self.config.grpc_addr, self.config.grpc_port); - - let receivers = std::mem::take(&mut self.receivers); - - // Collect pusher handles so we can wait on them if shutdown is initiated - let mut handles: Vec> = Vec::with_capacity(receivers.len()); - - // Initialize each push loop - for mut rx in receivers.into_iter() { - let endpoint = endpoint.clone(); - let callback_url = callback_url.clone(); - - let handle = tokio::spawn(async move { - let mut worker = match WorkerServiceClient::connect(endpoint).await { - Ok(w) => w, - - Err(e) => { - error!("Failed to connect to worker - {:?}", e); - return; - } - }; - - while let Some(activation) = rx.recv().await { - // Receive activation from the channel - let id = activation.id.clone(); - - // Try to push activation to the worker service - if let Err(e) = push_task(&mut worker, activation, callback_url.clone()).await { - error!("Pushing activation {id} resulted in error - {:?}", e); - } else { - debug!("Activation {id} was sent to worker!"); - } - } - }); - - handles.push(handle); - } - - info!("Starting fetch loop..."); - let guard = get_shutdown_guard().shutdown_on_drop(); - - // Initialize the fetch loop - loop { - tokio::select! { - _ = guard.wait() => { - info!("Fetch loop received shutdown signal"); - break; - } - - _ = async { - debug!("About to fetch next activation..."); - self.fetch_activation().await; - } => {} - } - } - - info!("Activation dispatcher shutting down..."); - - // Close channels and drain any tasks still in the pushing pipeline - drop(std::mem::take(&mut self.senders)); - for handle in handles { - let _ = handle.await; - } - - info!("Activation dispatcher shut down."); - Ok(()) - } - - /// Grab the next pending activation from the store, mark it as processing, and send to push channel. - pub async fn fetch_activation(&mut self) { - let start = Instant::now(); - metrics::counter!("pusher.fetch_activation.runs").increment(1); - - debug!("Fetching next pending activation..."); - - match self.store.get_pending_activation(None, None).await { - Ok(Some(activation)) => { - let id = activation.id.clone(); - - let idx = self.next_sender_idx % self.senders.len(); - self.next_sender_idx = self.next_sender_idx.wrapping_add(1); - - if let Err(e) = self.senders[idx].send(activation).await { - error!("Failed to send activation {id} to worker - {:?}", e); - } - - metrics::histogram!("pusher.fetch_activation.duration").record(start.elapsed()); - } - - Ok(_) => { - debug!("No pending activations, sleeping briefly..."); - sleep(milliseconds(100)).await; - - metrics::histogram!("pusher.fetch_activation.duration").record(start.elapsed()); - } - - Err(e) => { - error!("Failed to fetch pending activations - {:?}", e); - sleep(milliseconds(100)).await; - - metrics::histogram!("pusher.fetch_activation.duration").record(start.elapsed()); - } - } - } -} - -/// Decode task activation and push it to a worker. -async fn push_task( - worker: &mut WorkerServiceClient, - activation: InflightActivation, - callback_url: String, -) -> Result<()> { - let start = Instant::now(); - let id = activation.id.clone(); - - // Try to decode activation (if it fails, we will see the error where `push_task` is called) - let task = TaskActivation::decode(&activation.activation as &[u8])?; - - let request = PushTaskRequest { - task: Some(task), - callback_url, - }; - - let result = match worker.push_task(request).await { - Ok(_) => { - debug!("Successfully sent activation {id} to worker service!"); - Ok(()) - } - - Err(e) => { - error!("Could not push activation {id} to worker service - {:?}", e); - Err(e.into()) - } - }; - - metrics::histogram!("pusher.push_task.duration").record(start.elapsed()); - result -} - -#[inline] -fn milliseconds(i: u64) -> Duration { - Duration::from_millis(i) -} - -#[cfg(test)] -mod tests { - use std::sync::{Arc, Mutex}; - - use crate::config::Config; - use crate::store::inflight_activation::{ - FailedTasksForwarder, InflightActivation, InflightActivationStatus, - InflightActivationStore, QueryResult, - }; - use crate::test_utils::{create_test_store, make_activations}; - - use anyhow::Error; - use async_trait::async_trait; - use chrono::{DateTime, Utc}; - - use super::TaskDispatcher; - - /// Mock store that returns activations from a queue for `get_pending_activation`. - struct MockStore { - activations: Mutex>, - } - - impl MockStore { - fn new(activations: Vec) -> Arc { - Arc::new(Self { - activations: Mutex::new(activations), - }) - } - } - - #[async_trait] - impl InflightActivationStore for MockStore { - async fn vacuum_db(&self) -> Result<(), Error> { - unimplemented!() - } - - async fn full_vacuum_db(&self) -> Result<(), Error> { - unimplemented!() - } - - async fn db_size(&self) -> Result { - unimplemented!() - } - - async fn get_by_id(&self, _id: &str) -> Result, Error> { - unimplemented!() - } - - async fn store(&self, _batch: Vec) -> Result { - unimplemented!() - } - - async fn get_pending_activations_from_namespaces( - &self, - _application: Option<&str>, - _namespaces: Option<&[String]>, - limit: Option, - ) -> Result, Error> { - let limit = limit.unwrap_or(1) as usize; - let mut list = self.activations.lock().unwrap(); - let n = limit.min(list.len()); - - if n == 0 { - return Ok(vec![]); - } - - Ok(list.drain(..n).collect()) - } - - async fn pending_activation_max_lag(&self, _now: &DateTime) -> f64 { - unimplemented!() - } - - async fn count_by_status(&self, _status: InflightActivationStatus) -> Result { - Ok(self.activations.lock().unwrap().len()) - } - - async fn count(&self) -> Result { - Ok(self.activations.lock().unwrap().len()) - } - - async fn set_status( - &self, - _id: &str, - _status: InflightActivationStatus, - ) -> Result, Error> { - unimplemented!() - } - - async fn set_processing_deadline( - &self, - _id: &str, - _deadline: Option>, - ) -> Result<(), Error> { - unimplemented!() - } - - async fn delete_activation(&self, _id: &str) -> Result<(), Error> { - unimplemented!() - } - - async fn get_retry_activations(&self) -> Result, Error> { - unimplemented!() - } - - async fn clear(&self) -> Result<(), Error> { - unimplemented!() - } - - async fn handle_processing_deadline(&self) -> Result { - unimplemented!() - } - - async fn handle_processing_attempts(&self) -> Result { - unimplemented!() - } - - async fn handle_expires_at(&self) -> Result { - unimplemented!() - } - - async fn handle_delay_until(&self) -> Result { - unimplemented!() - } - - async fn handle_failed_tasks(&self) -> Result { - unimplemented!() - } - - async fn mark_completed(&self, _ids: Vec) -> Result { - unimplemented!() - } - - async fn remove_completed(&self) -> Result { - unimplemented!() - } - - async fn remove_killswitched( - &self, - _killswitched_tasks: Vec, - ) -> Result { - unimplemented!() - } - } - - /// Asserts that a dispatcher built with X pushers has exactly X senders (and thus X receivers). - #[test] - fn pushers_x_creates_x_senders_and_receivers() { - // Use an empty mock store because we only care about construction, not fetching - let store: Arc = MockStore::new(vec![]); - - let config = Arc::new(Config { - pushers: 5, - push_queue_size: 10, - ..Config::default() - }); - - let dispatcher = TaskDispatcher::new(config, store); - - // One sender (and one receiver) per pusher - assert_eq!(dispatcher.pusher_count(), 5); - } - - /// Asserts that the fetch loop distributes activations round-robin across channels (0, 1, 2, 0, 1, 2, ...) - #[tokio::test] - async fn round_robin_sends_to_channels_0_1_2_0_1_2() { - // Six activations (id_0 .. id_5) so we get two full cycles across three channels - let activations = make_activations(6); - let store = MockStore::new(activations); - - let config = Arc::new(Config { - pushers: 3, - push_queue_size: 10, - ..Config::default() - }); - - let mut dispatcher = TaskDispatcher::new(config, store); - - // Take receivers so we can drain them - dispatcher keeps senders and will push to them - let mut receivers = dispatcher.take_receivers(); - assert_eq!(receivers.len(), 3); - - // Run the fetch loop six times - each run takes one activation from the mock and sends to next channel - for _ in 0..6 { - dispatcher.fetch_activation().await; - } - - // Receive in the same order the dispatcher sends - channel 0, then 1, then 2, then 0, 1, 2 - let mut received_by_channel: Vec> = vec![vec![], vec![], vec![]]; - for i in 0..6 { - let idx = i % 3; - let activation = receivers[idx].recv().await.expect("activation"); - received_by_channel[idx].push(activation.id.clone()); - } - - // Make sure round-robin works as intended... - // - Activations 1 and 4 go to channel 0 - // - Activations 2 and 5 go to channel 1 - // - Activations 3 and 6 go to channel 2 - assert_eq!(received_by_channel[0], &["id_0", "id_3"]); - assert_eq!(received_by_channel[1], &["id_1", "id_4"]); - assert_eq!(received_by_channel[2], &["id_2", "id_5"]); - } - - /// Asserts that after N fetch steps the store has zero pending activations (each fetch marks one as processing). - #[tokio::test] - async fn fetch_loop_drains_store() { - let activations = make_activations(3); - let store = create_test_store("sqlite").await; - - // Add activations to test store - store.store(activations).await.unwrap(); - assert_eq!(store.count_pending_activations().await.unwrap(), 3); - - let config = Arc::new(Config { - pushers: 2, - push_queue_size: 10, - ..Config::default() - }); - - let mut dispatcher = TaskDispatcher::new(config, store.clone()); - let mut receivers = dispatcher.take_receivers(); - - // Run fetch three times - each call gets one pending activation and moves it to processing - for _ in 0..3 { - dispatcher.fetch_activation().await; - } - - // Drain all activations from the channels so we've fully consumed what was fetched - let mut received = 0; - - for mut rx in receivers.drain(..) { - while rx.try_recv().is_ok() { - received += 1; - } - } - - // Have all activations been received? - assert_eq!(received, 3); - - // Real store marks as processing on `get_pending_activation` - so no pending left - assert_eq!(store.count_pending_activations().await.unwrap(), 0); - } -} diff --git a/src/fetch.rs b/src/fetch.rs new file mode 100644 index 00000000..10cfbbf4 --- /dev/null +++ b/src/fetch.rs @@ -0,0 +1,111 @@ +use std::sync::Arc; +use std::time::{Duration, Instant}; + +use anyhow::Result; +use elegant_departure::get_shutdown_guard; +use tokio::time::sleep; +use tracing::{debug, error, info}; + +use crate::config::Config; +use crate::push::PushPool; +use crate::store::inflight_activation::InflightActivationStore; + +/// Wrapper around `config.fetch_threads` asynchronous tasks, each of which fetches a pending activation from the store, passes is to the push pool, and repeats. +pub struct FetchPool { + /// Inflight activation store. + store: Arc, + + /// Pool of push threads that push activations to the worker service. + push_pool: Arc, + + /// Taskbroker configuration. + config: Arc, +} + +impl FetchPool { + /// Initialize a new fetch pool. + pub fn new( + store: Arc, + config: Arc, + push_pool: Arc, + ) -> Self { + Self { + store, + push_pool, + config, + } + } + + /// Spawn `config.fetch_threads` asynchronous tasks, each of which repeatedly moves pending activations from the store to the push pool until the shutdown signal is received. + pub async fn start(&self) -> Result<()> { + let mut handles = vec![]; + + for _ in 0..self.config.fetch_threads.max(1) { + let guard = get_shutdown_guard().shutdown_on_drop(); + + let store = self.store.clone(); + let push_pool = self.push_pool.clone(); + + let handle = tokio::spawn(async move { + loop { + tokio::select! { + _ = guard.wait() => { + info!("Fetch loop received shutdown signal"); + break; + } + + _ = async { + debug!("About to fetch next activation..."); + fetch_activations(store.clone(), push_pool.clone()).await; + } => {} + } + } + }); + + handles.push(handle); + } + + for handle in handles { + if let Err(e) = handle.await { + return Err(e.into()); + } + } + + Ok(()) + } +} + +/// Grab the next pending activation from the store, mark it as processing, and send to push channel. +pub async fn fetch_activations(store: Arc, push_pool: Arc) { + let start = Instant::now(); + metrics::counter!("fetch.fetch_activations.runs").increment(1); + + debug!("Fetching next pending activation..."); + + match store.get_pending_activation(None, None).await { + Ok(Some(activation)) => { + let id = activation.id.clone(); + debug!("Atomically fetched and marked task {id} as processing"); + + if let Err(e) = push_pool.submit(activation).await { + error!("Failed to submit task {id} to push pool - {:?}", e); + } + + metrics::histogram!("fetch.fetch_activations.duration").record(start.elapsed()); + } + + Ok(_) => { + debug!("No pending activations, sleeping briefly..."); + sleep(Duration::from_millis(100)).await; + + metrics::histogram!("fetch.fetch_activations.duration").record(start.elapsed()); + } + + Err(e) => { + error!("Failed to fetch pending activation - {:?}", e); + sleep(Duration::from_millis(100)).await; + + metrics::histogram!("fetch.fetch_activations.duration").record(start.elapsed()); + } + } +} diff --git a/src/lib.rs b/src/lib.rs index 6ff2b08d..baf480d7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,11 +2,12 @@ use clap::Parser; use std::fs; pub mod config; -pub mod dispatch; +pub mod fetch; pub mod grpc; pub mod kafka; pub mod logging; pub mod metrics; +pub mod push; pub mod runtime_config; pub mod store; pub mod test_utils; diff --git a/src/main.rs b/src/main.rs index 909fb594..4efdbc58 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,10 +2,11 @@ use anyhow::{Error, anyhow}; use chrono::Utc; use clap::Parser; use std::{sync::Arc, time::Duration}; -use taskbroker::dispatch::TaskDispatcher; +use taskbroker::fetch::FetchPool; use taskbroker::kafka::inflight_activation_batcher::{ ActivationBatcherConfig, InflightActivationBatcher, }; +use taskbroker::push::PushPool; use taskbroker::upkeep::upkeep; use tokio::signal::unix::SignalKind; use tokio::task::JoinHandle; @@ -239,24 +240,22 @@ async fn main() -> Result<(), Error> { } }); - // Activation dispatchers - let dispatchers = if config.push_mode { - info!("Running in PUSH mode"); - - (0..config.dispatchers) - .map(|_| { - let store = store.clone(); - let config = config.clone(); - - tokio::spawn(async move { - let dispatcher = TaskDispatcher::new(config, store); - dispatcher.start().await - }) - }) - .collect() + // Initialize push and fetch pools + let push_pool = Arc::new(PushPool::new(config.clone())); + let fetch_pool = FetchPool::new(store.clone(), config.clone(), push_pool.clone()); + + // Initialize push threads + let push_task = if config.push_mode { + Some(tokio::spawn(async move { push_pool.start().await })) + } else { + None + }; + + // Initialize fetch threads + let fetch_task = if config.push_mode { + Some(tokio::spawn(async move { fetch_pool.start().await })) } else { - info!("Running in PULL mode"); - vec![] + None }; let mut departure = elegant_departure::tokio::depart() @@ -269,11 +268,14 @@ async fn main() -> Result<(), Error> { .on_completion(log_task_completion("upkeep_task", upkeep_task)) .on_completion(log_task_completion("maintenance_task", maintenance_task)); - // Register each activation dispatch task - for (i, handle) in dispatchers.into_iter().enumerate() { - let task_name = format!("activation_dispatcher_{}", i); - departure = departure.on_completion(log_task_completion(task_name, handle)); + if let Some(task) = push_task { + departure = departure.on_completion(log_task_completion("push_task", task)); + } + + if let Some(task) = fetch_task { + departure = departure.on_completion(log_task_completion("fetch_task", task)); } + departure.await; Ok(()) } diff --git a/src/push.rs b/src/push.rs new file mode 100644 index 00000000..6ffa23bc --- /dev/null +++ b/src/push.rs @@ -0,0 +1,146 @@ +use std::sync::Arc; +use std::time::Instant; + +use anyhow::Result; +use elegant_departure::get_shutdown_guard; +use flume::{Receiver, Sender}; +use prost::Message; +use sentry_protos::taskbroker::v1::worker_service_client::WorkerServiceClient; +use sentry_protos::taskbroker::v1::{PushTaskRequest, TaskActivation}; +use tonic::transport::Channel; +use tracing::{debug, error, info}; + +use crate::config::Config; +use crate::store::inflight_activation::InflightActivation; + +/// Wrapper around `config.push_threads` asynchronous tasks, each of which receives an activation from the channel, sends it to the worker service, and repeats. +pub struct PushPool { + /// The sending end of a channel that accepts task activations. + sender: Sender, + + /// The receiving end of a channel that accepts task activations. + receiver: Receiver, + + /// Taskbroker configuration. + config: Arc, +} + +impl PushPool { + /// Initialize a new push pool. + pub fn new(config: Arc) -> Self { + let (sender, receiver) = flume::bounded(config.push_queue_size); + + Self { + sender, + receiver, + config, + } + } + + /// Spawn `config.push_threads` asynchronous tasks, each of which repeatedly moves pending activations from the channel to the worker service until the shutdown signal is received. + pub async fn start(&self) -> Result<()> { + let mut handles = vec![]; + + for _ in 0..self.config.push_threads { + let endpoint = self.config.worker_endpoint.clone(); + + let callback_url = format!("{}:{}", self.config.grpc_addr, self.config.grpc_port); + let receiver = self.receiver.clone(); + let guard = get_shutdown_guard().shutdown_on_drop(); + + let handle = tokio::spawn(async move { + let mut worker = match WorkerServiceClient::connect(endpoint).await { + Ok(w) => w, + Err(e) => { + error!("Failed to connect to worker - {:?}", e); + return; + } + }; + + loop { + tokio::select! { + _ = guard.wait() => { + info!("Push worker received shutdown signal"); + break; + } + + message = receiver.recv_async() => { + let activation = match message { + // Received activation from fetch thread + Ok(a) => a, + + // Channel closed + Err(_) => break + }; + + let id = activation.id.clone(); + + match push_task(&mut worker, activation, callback_url.clone()).await { + Ok(_) => debug!("Activation {id} was sent to worker!"), + Err(e) => error!("Pushing activation {id} resulted in error - {:?}", e) + }; + } + } + } + + // Drain channel before exiting + for activation in receiver.drain() { + let id = activation.id.clone(); + + match push_task(&mut worker, activation, callback_url.clone()).await { + Ok(_) => debug!("Activation {id} was sent to worker!"), + Err(e) => error!("Pushing activation {id} resulted in error - {:?}", e), + }; + } + }); + + handles.push(handle); + } + + for handle in handles { + if let Err(e) = handle.await { + return Err(e.into()); + } + } + + Ok(()) + } + + /// Send an activation to the internal asynchronous MPMC channel used by all running push threads. + pub async fn submit(&self, activation: InflightActivation) -> Result<()> { + Ok(self.sender.send_async(activation).await?) + } +} + +/// Decode task activation and push it to a worker. +async fn push_task( + worker: &mut WorkerServiceClient, + activation: InflightActivation, + callback_url: String, +) -> Result<()> { + let start = Instant::now(); + let id = activation.id.clone(); + + // Try to decode activation (if it fails, we will see the error where `push_task` is called) + let task = TaskActivation::decode(&activation.activation as &[u8])?; + + let request = PushTaskRequest { + task: Some(task), + callback_url, + }; + + let result = match worker.push_task(request).await { + Ok(_) => { + debug!("Successfully sent activation {id} to worker service!"); + Ok(()) + } + + Err(e) => { + error!("Could not push activation {id} to worker service - {:?}", e); + Err(e.into()) + } + }; + + metrics::histogram!("pusher.push_task.duration").record(start.elapsed()); + result +} From 043559ad3b320d7c5e2ce9ecc477f8286cd3e4c2 Mon Sep 17 00:00:00 2001 From: james-mcnulty Date: Tue, 17 Mar 2026 17:36:19 -0700 Subject: [PATCH 5/8] Initialize gRPC Server w/`0.0.0.0` --- src/main.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main.rs b/src/main.rs index 4efdbc58..0a2ca923 100644 --- a/src/main.rs +++ b/src/main.rs @@ -196,7 +196,7 @@ async fn main() -> Result<(), Error> { let config = config.clone(); async move { - let addr = format!("{}:{}", config.grpc_addr, config.grpc_port) + let addr = format!("0.0.0.0:{}", config.grpc_port) .parse() .expect("Failed to parse address"); From f70cc3a0171c8e987eb4865b0552c9f10a1bf0fd Mon Sep 17 00:00:00 2001 From: james-mcnulty Date: Tue, 17 Mar 2026 19:13:28 -0700 Subject: [PATCH 6/8] Add `PushPool` Unit Tests --- src/config.rs | 52 +++++++++++++++++ src/push.rs | 153 ++++++++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 200 insertions(+), 5 deletions(-) diff --git a/src/config.rs b/src/config.rs index e9c03dfe..2831cae1 100644 --- a/src/config.rs +++ b/src/config.rs @@ -254,6 +254,12 @@ pub struct Config { /// The worker service endpoint. pub worker_endpoint: String, + + /// The hostname used to construct `callback_url` for task push requests. + pub callback_addr: String, + + /// The port used to construct `callback_url` for task push requests. + pub callback_port: u32, } impl Default for Config { @@ -328,6 +334,8 @@ impl Default for Config { push_threads: 1, push_queue_size: 1, worker_endpoint: "http://127.0.0.1:50052".into(), + callback_addr: "0.0.0.0".into(), + callback_port: 50051, } } } @@ -732,4 +740,48 @@ mod tests { Ok(()) }); } + + #[test] + fn test_default_push_callback_fields() { + let config = Config::default(); + assert_eq!(config.callback_addr, "0.0.0.0"); + assert_eq!(config.callback_port, 50051); + } + + #[test] + fn test_from_args_push_callback_fields_from_env() { + Jail::expect_with(|jail| { + jail.set_env("TASKBROKER_CALLBACK_ADDR", "127.0.0.1"); + jail.set_env("TASKBROKER_CALLBACK_PORT", "51000"); + + let args = Args { config: None }; + let config = Config::from_args(&args).unwrap(); + assert_eq!(config.callback_addr, "127.0.0.1"); + assert_eq!(config.callback_port, 51000); + + Ok(()) + }); + } + + #[test] + fn test_from_args_push_callback_fields_from_config_file() { + Jail::expect_with(|jail| { + jail.create_file( + "config.yaml", + r#" + callback_addr: 10.0.0.1 + callback_port: 52000 + "#, + )?; + + let args = Args { + config: Some("config.yaml".to_owned()), + }; + let config = Config::from_args(&args).unwrap(); + assert_eq!(config.callback_addr, "10.0.0.1"); + assert_eq!(config.callback_port, 52000); + + Ok(()) + }); + } } diff --git a/src/push.rs b/src/push.rs index 6ffa23bc..ee5aeb1d 100644 --- a/src/push.rs +++ b/src/push.rs @@ -7,12 +7,28 @@ use flume::{Receiver, Sender}; use prost::Message; use sentry_protos::taskbroker::v1::worker_service_client::WorkerServiceClient; use sentry_protos::taskbroker::v1::{PushTaskRequest, TaskActivation}; +use tonic::async_trait; use tonic::transport::Channel; use tracing::{debug, error, info}; use crate::config::Config; use crate::store::inflight_activation::InflightActivation; +/// Thin interface for the worker client. It mostly serves to enable proper unit testing, but it also decouples the actual client implementation from our pushing logic. +#[async_trait] +trait WorkerClient { + /// Send a single `PushTaskRequest` to the worker service. + async fn send(&mut self, request: PushTaskRequest) -> Result<()>; +} + +#[async_trait] +impl WorkerClient for WorkerServiceClient { + async fn send(&mut self, request: PushTaskRequest) -> Result<()> { + self.push_task(request).await?; + Ok(()) + } +} + /// Wrapper around `config.push_threads` asynchronous tasks, each of which receives an activation from the channel, sends it to the worker service, and repeats. pub struct PushPool { /// The sending end of a channel that accepts task activations. @@ -44,7 +60,11 @@ impl PushPool { for _ in 0..self.config.push_threads { let endpoint = self.config.worker_endpoint.clone(); - let callback_url = format!("{}:{}", self.config.grpc_addr, self.config.grpc_port); + let callback_url = format!( + "{}:{}", + self.config.callback_addr, self.config.callback_port + ); + let receiver = self.receiver.clone(); let guard = get_shutdown_guard().shutdown_on_drop(); @@ -113,8 +133,8 @@ impl PushPool { } /// Decode task activation and push it to a worker. -async fn push_task( - worker: &mut WorkerServiceClient, +async fn push_task( + worker: &mut W, activation: InflightActivation, callback_url: String, ) -> Result<()> { @@ -129,7 +149,7 @@ async fn push_task( callback_url, }; - let result = match worker.push_task(request).await { + let result = match worker.send(request).await { Ok(_) => { debug!("Successfully sent activation {id} to worker service!"); Ok(()) @@ -141,6 +161,129 @@ async fn push_task( } }; - metrics::histogram!("pusher.push_task.duration").record(start.elapsed()); + metrics::histogram!("push.push_task.duration").record(start.elapsed()); result } + +#[cfg(test)] +mod tests { + use anyhow::anyhow; + use std::sync::Arc; + use tokio::time::{Duration, timeout}; + + use super::*; + use crate::test_utils::make_activations; + + /// Fake worker client for unit testing. + struct MockWorkerClient { + /// Capture all received requests so we can assert things about them. + captured_requests: Vec, + + /// Should requests to the worker client fail? + should_fail: bool, + } + + impl MockWorkerClient { + fn new(should_fail: bool) -> Self { + let captured_requests = vec![]; + + Self { + captured_requests, + should_fail, + } + } + } + + #[async_trait] + impl WorkerClient for MockWorkerClient { + async fn send(&mut self, request: PushTaskRequest) -> Result<()> { + self.captured_requests.push(request); + + if self.should_fail { + return Err(anyhow!("mock send failure")); + } + + Ok(()) + } + } + + #[tokio::test] + async fn push_task_returns_ok_on_client_success() { + let activation = make_activations(1).remove(0); + let mut worker = MockWorkerClient::new(false); + let callback_url = "taskbroker:50051".to_string(); + + let result = push_task(&mut worker, activation.clone(), callback_url.clone()).await; + assert!(result.is_ok(), "push_task should succeed"); + assert_eq!(worker.captured_requests.len(), 1); + + let request = &worker.captured_requests[0]; + assert_eq!(request.callback_url, callback_url); + assert_eq!( + request.task.as_ref().map(|task| task.id.as_str()), + Some(activation.id.as_str()) + ); + } + + #[tokio::test] + async fn push_task_returns_err_on_invalid_payload() { + let mut activation = make_activations(1).remove(0); + activation.activation = vec![1, 2, 3, 4]; + + let mut worker = MockWorkerClient::new(false); + let result = push_task(&mut worker, activation, "taskbroker:50051".to_string()).await; + + assert!(result.is_err(), "invalid payload should fail decoding"); + assert!( + worker.captured_requests.is_empty(), + "worker should not be called if decode fails" + ); + } + + #[tokio::test] + async fn push_task_propagates_client_error() { + let activation = make_activations(1).remove(0); + let mut worker = MockWorkerClient::new(true); + + let result = push_task(&mut worker, activation, "taskbroker:50051".to_string()).await; + assert!(result.is_err(), "worker send errors should propagate"); + assert_eq!(worker.captured_requests.len(), 1); + } + + #[tokio::test] + async fn push_pool_submit_enqueues_item() { + let config = Arc::new(Config { + push_queue_size: 2, + ..Config::default() + }); + + let pool = PushPool::new(config); + let activation = make_activations(1).remove(0); + + let result = pool.submit(activation).await; + assert!(result.is_ok(), "submit should enqueue activation"); + } + + #[tokio::test] + async fn push_pool_submit_backpressures_when_queue_full() { + let config = Arc::new(Config { + push_queue_size: 1, + ..Config::default() + }); + + let pool = PushPool::new(config); + + let first = make_activations(1).remove(0); + let second = make_activations(1).remove(0); + + pool.submit(first) + .await + .expect("first submit should fill queue"); + + let second_submit = timeout(Duration::from_millis(50), pool.submit(second)).await; + assert!( + second_submit.is_err(), + "second submit should block when queue is full" + ); + } +} From dafa06c254dcea6f2b5b284263ed9e1715e670e4 Mon Sep 17 00:00:00 2001 From: james-mcnulty Date: Tue, 17 Mar 2026 19:35:07 -0700 Subject: [PATCH 7/8] Add `FetchPool` Unit Tests --- src/fetch.rs | 309 +++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 300 insertions(+), 9 deletions(-) diff --git a/src/fetch.rs b/src/fetch.rs index 10cfbbf4..94ffc1fa 100644 --- a/src/fetch.rs +++ b/src/fetch.rs @@ -4,35 +4,51 @@ use std::time::{Duration, Instant}; use anyhow::Result; use elegant_departure::get_shutdown_guard; use tokio::time::sleep; +use tonic::async_trait; use tracing::{debug, error, info}; use crate::config::Config; use crate::push::PushPool; +use crate::store::inflight_activation::InflightActivation; use crate::store::inflight_activation::InflightActivationStore; +/// Thin interface for the push pool. It mostly serves to enable proper unit testing, but it also decouples fetch logic from push logic even further. +#[async_trait] +pub trait TaskPusher { + /// Push a single task to the worker service. + async fn push_task(&self, activation: InflightActivation) -> Result<()>; +} + +#[async_trait] +impl TaskPusher for PushPool { + async fn push_task(&self, activation: InflightActivation) -> Result<()> { + self.submit(activation).await + } +} + /// Wrapper around `config.fetch_threads` asynchronous tasks, each of which fetches a pending activation from the store, passes is to the push pool, and repeats. -pub struct FetchPool { +pub struct FetchPool { /// Inflight activation store. store: Arc, /// Pool of push threads that push activations to the worker service. - push_pool: Arc, + pusher: Arc, /// Taskbroker configuration. config: Arc, } -impl FetchPool { +impl FetchPool { /// Initialize a new fetch pool. pub fn new( store: Arc, config: Arc, - push_pool: Arc, + pusher: Arc, ) -> Self { Self { store, - push_pool, config, + pusher, } } @@ -44,7 +60,7 @@ impl FetchPool { let guard = get_shutdown_guard().shutdown_on_drop(); let store = self.store.clone(); - let push_pool = self.push_pool.clone(); + let task_pusher = self.pusher.clone(); let handle = tokio::spawn(async move { loop { @@ -56,7 +72,7 @@ impl FetchPool { _ = async { debug!("About to fetch next activation..."); - fetch_activations(store.clone(), push_pool.clone()).await; + fetch_activations(store.clone(), task_pusher.clone()).await; } => {} } } @@ -76,7 +92,10 @@ impl FetchPool { } /// Grab the next pending activation from the store, mark it as processing, and send to push channel. -pub async fn fetch_activations(store: Arc, push_pool: Arc) { +pub async fn fetch_activations( + store: Arc, + pusher: Arc, +) { let start = Instant::now(); metrics::counter!("fetch.fetch_activations.runs").increment(1); @@ -87,7 +106,7 @@ pub async fn fetch_activations(store: Arc, push_poo let id = activation.id.clone(); debug!("Atomically fetched and marked task {id} as processing"); - if let Err(e) = push_pool.submit(activation).await { + if let Err(e) = pusher.push_task(activation).await { error!("Failed to submit task {id} to push pool - {:?}", e); } @@ -109,3 +128,275 @@ pub async fn fetch_activations(store: Arc, push_poo } } } + +#[cfg(test)] +mod tests { + use std::sync::Arc; + use std::sync::atomic::{AtomicUsize, Ordering}; + + use anyhow::{Error, anyhow}; + use chrono::{DateTime, Utc}; + use tokio::sync::Mutex; + use tokio::time::{Duration, timeout}; + + use super::*; + use crate::store::inflight_activation::{ + FailedTasksForwarder, InflightActivationStatus, QueryResult, + }; + use crate::test_utils::make_activations; + + enum MockPendingResult { + Some(InflightActivation), + None, + Err, + } + + /// Fake store for testing. + struct MockStore { + /// How should all calls to `get_pending_activation` respond? + pending_result: MockPendingResult, + + /// How many calls to `get_pending_activation` have been performed? + pending_calls: AtomicUsize, + } + + impl MockStore { + fn new(pending_result: MockPendingResult) -> Self { + let pending_calls = AtomicUsize::new(0); + + Self { + pending_result, + pending_calls, + } + } + } + + #[async_trait] + impl InflightActivationStore for MockStore { + async fn vacuum_db(&self) -> Result<(), Error> { + unimplemented!() + } + + async fn full_vacuum_db(&self) -> Result<(), Error> { + unimplemented!() + } + + async fn db_size(&self) -> Result { + unimplemented!() + } + + async fn get_by_id(&self, _id: &str) -> Result, Error> { + unimplemented!() + } + async fn store(&self, _batch: Vec) -> Result { + unimplemented!() + } + + async fn get_pending_activation( + &self, + _application: Option<&str>, + _namespace: Option<&str>, + ) -> Result, Error> { + self.pending_calls.fetch_add(1, Ordering::SeqCst); + match &self.pending_result { + MockPendingResult::Some(activation) => Ok(Some(activation.clone())), + MockPendingResult::None => Ok(None), + MockPendingResult::Err => Err(anyhow!("mock store error")), + } + } + + async fn get_pending_activations_from_namespaces( + &self, + _application: Option<&str>, + _namespaces: Option<&[String]>, + _limit: Option, + ) -> Result, Error> { + unimplemented!() + } + + async fn pending_activation_max_lag(&self, _now: &DateTime) -> f64 { + unimplemented!() + } + + async fn count_by_status(&self, _status: InflightActivationStatus) -> Result { + unimplemented!() + } + + async fn count(&self) -> Result { + unimplemented!() + } + + async fn set_status( + &self, + _id: &str, + _status: InflightActivationStatus, + ) -> Result, Error> { + unimplemented!() + } + + async fn set_processing_deadline( + &self, + _id: &str, + _deadline: Option>, + ) -> Result<(), Error> { + unimplemented!() + } + + async fn delete_activation(&self, _id: &str) -> Result<(), Error> { + unimplemented!() + } + + async fn get_retry_activations(&self) -> Result, Error> { + unimplemented!() + } + + async fn clear(&self) -> Result<(), Error> { + unimplemented!() + } + + async fn handle_processing_deadline(&self) -> Result { + unimplemented!() + } + + async fn handle_processing_attempts(&self) -> Result { + unimplemented!() + } + + async fn handle_expires_at(&self) -> Result { + unimplemented!() + } + + async fn handle_delay_until(&self) -> Result { + unimplemented!() + } + + async fn handle_failed_tasks(&self) -> Result { + unimplemented!() + } + + async fn mark_completed(&self, _ids: Vec) -> Result { + unimplemented!() + } + + async fn remove_completed(&self) -> Result { + unimplemented!() + } + + async fn remove_killswitched( + &self, + _killswitched_tasks: Vec, + ) -> Result { + unimplemented!() + } + } + + /// Fake push pool for testing. + struct MockTaskPusher { + /// List of the IDs of all the activations that have been pushed. + pushed_ids: Mutex>, + + /// Should `push_task` fail? + should_fail: bool, + } + + impl MockTaskPusher { + fn new(should_fail: bool) -> Self { + let pushed_ids = Mutex::new(vec![]); + + Self { + pushed_ids, + should_fail, + } + } + } + + #[async_trait] + impl TaskPusher for MockTaskPusher { + async fn push_task(&self, activation: InflightActivation) -> Result<()> { + self.pushed_ids.lock().await.push(activation.id); + + if self.should_fail { + return Err(anyhow!("mock push error")); + } + + Ok(()) + } + } + + #[tokio::test] + async fn fetch_activations_submits_when_pending_exists() { + let activation = make_activations(1).remove(0); + let store: Arc = + Arc::new(MockStore::new(MockPendingResult::Some(activation.clone()))); + let pusher = Arc::new(MockTaskPusher::new(false)); + + fetch_activations(store, pusher.clone()).await; + + let pushed = pusher.pushed_ids.lock().await; + assert_eq!(pushed.len(), 1); + assert_eq!(pushed[0], activation.id); + } + + #[tokio::test] + async fn fetch_activations_logs_submit_error_but_does_not_fail() { + let activation = make_activations(1).remove(0); + let store: Arc = + Arc::new(MockStore::new(MockPendingResult::Some(activation))); + let pusher = Arc::new(MockTaskPusher::new(true)); + + fetch_activations(store, pusher.clone()).await; + + let pushed = pusher.pushed_ids.lock().await; + assert_eq!(pushed.len(), 1, "should attempt one push even if it fails"); + } + + #[tokio::test] + async fn fetch_activations_no_pending_does_not_submit() { + let store: Arc = + Arc::new(MockStore::new(MockPendingResult::None)); + let pusher = Arc::new(MockTaskPusher::new(false)); + + fetch_activations(store, pusher.clone()).await; + + let pushed = pusher.pushed_ids.lock().await; + assert!( + pushed.is_empty(), + "should not push if no activation is pending" + ); + } + + #[tokio::test] + async fn fetch_activations_store_error_does_not_submit() { + let store: Arc = + Arc::new(MockStore::new(MockPendingResult::Err)); + let pusher = Arc::new(MockTaskPusher::new(false)); + + fetch_activations(store, pusher.clone()).await; + + let pushed = pusher.pushed_ids.lock().await; + assert!( + pushed.is_empty(), + "should not push when pending activation lookup fails" + ); + } + + #[tokio::test] + async fn fetch_pool_start_spawns_at_least_one_worker_when_fetch_threads_zero() { + let store: Arc = + Arc::new(MockStore::new(MockPendingResult::None)); + let pusher = Arc::new(MockTaskPusher::new(false)); + + let config = Arc::new(Config { + fetch_threads: 0, + ..Config::default() + }); + + let pool = FetchPool::new(store, config, pusher); + + let result = timeout(Duration::from_millis(50), pool.start()).await; + assert!( + result.is_err(), + "start() should not complete immediately when fetch_threads is 0 because .max(1) starts one worker loop" + ); + } +} From e47a1e8ce1cfab0293c28394452730941475d58c Mon Sep 17 00:00:00 2001 From: james-mcnulty Date: Tue, 17 Mar 2026 19:39:45 -0700 Subject: [PATCH 8/8] Fix Linting --- src/fetch.rs | 1 + src/push.rs | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/fetch.rs b/src/fetch.rs index 94ffc1fa..dbd98132 100644 --- a/src/fetch.rs +++ b/src/fetch.rs @@ -145,6 +145,7 @@ mod tests { }; use crate::test_utils::make_activations; + #[allow(clippy::large_enum_variant)] enum MockPendingResult { Some(InflightActivation), None, diff --git a/src/push.rs b/src/push.rs index ee5aeb1d..3156ec78 100644 --- a/src/push.rs +++ b/src/push.rs @@ -157,7 +157,7 @@ async fn push_task( Err(e) => { error!("Could not push activation {id} to worker service - {:?}", e); - Err(e.into()) + Err(e) } };