Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 49 additions & 4 deletions internal/consumer/consumer.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package consumer

import (
"context"
"sync"

"github.com/hookdeck/outpost/internal/logging"
"github.com/hookdeck/outpost/internal/mqs"
Expand Down Expand Up @@ -67,8 +68,53 @@ var _ Consumer = &consumerImpl{}
func (c *consumerImpl) Run(ctx context.Context) error {
defer c.subscription.Shutdown(ctx)

tracerProvider := otel.GetTracerProvider()
tracer := tracerProvider.Tracer("github.com/hookdeck/outpost/internal/consumer")
// If the subscription manages its own concurrency (e.g. GCP native SDK
// with MaxOutstandingMessages), skip the consumer-side semaphore.
if cs, ok := c.subscription.(mqs.ConcurrentSubscription); ok && cs.SupportsConcurrency() {
return c.runConcurrent(ctx)
}
return c.runWithSemaphore(ctx)
}

// runConcurrent is used when the subscription manages flow control internally.
// A WaitGroup tracks in-flight handlers for graceful shutdown.
func (c *consumerImpl) runConcurrent(ctx context.Context) error {
tracer := otel.GetTracerProvider().Tracer("github.com/hookdeck/outpost/internal/consumer")

var wg sync.WaitGroup
var subscriptionErr error

recvLoop:
for {
msg, err := c.subscription.Receive(ctx)
if err != nil {
subscriptionErr = err
break recvLoop
}

wg.Add(1)
go func() {
defer wg.Done()

handlerCtx, span := tracer.Start(context.Background(), c.actionWithName("Consumer.Handle"))
defer span.End()

if err := c.handler.Handle(handlerCtx, msg); err != nil {
span.RecordError(err)
if c.logger != nil {
c.logger.Ctx(handlerCtx).Error("consumer handler error", zap.String("name", c.name), zap.Error(err))
}
}
}()
}

wg.Wait()
return subscriptionErr
}

// runWithSemaphore limits concurrency via a channel-based semaphore.
func (c *consumerImpl) runWithSemaphore(ctx context.Context) error {
tracer := otel.GetTracerProvider().Tracer("github.com/hookdeck/outpost/internal/consumer")

var subscriptionErr error

Expand All @@ -93,8 +139,7 @@ recvLoop:
handlerCtx, span := tracer.Start(context.Background(), c.actionWithName("Consumer.Handle"))
defer span.End()

err = c.handler.Handle(handlerCtx, msg)
if err != nil {
if err := c.handler.Handle(handlerCtx, msg); err != nil {
span.RecordError(err)
if c.logger != nil {
c.logger.Ctx(handlerCtx).Error("consumer handler error", zap.String("name", c.name), zap.Error(err))
Expand Down
4 changes: 2 additions & 2 deletions internal/deliverymq/deliverymq.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,6 @@ func (q *DeliveryMQ) Publish(ctx context.Context, task models.DeliveryTask) erro
return q.queue.Publish(ctx, &task)
}

func (q *DeliveryMQ) Subscribe(ctx context.Context) (mqs.Subscription, error) {
return q.queue.Subscribe(ctx)
func (q *DeliveryMQ) Subscribe(ctx context.Context, opts ...mqs.SubscribeOption) (mqs.Subscription, error) {
return q.queue.Subscribe(ctx, opts...)
}
4 changes: 2 additions & 2 deletions internal/logmq/logmq.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,6 @@ func (q *LogMQ) Publish(ctx context.Context, entry models.LogEntry) error {
return q.queue.Publish(ctx, &entry)
}

func (q *LogMQ) Subscribe(ctx context.Context) (mqs.Subscription, error) {
return q.queue.Subscribe(ctx)
func (q *LogMQ) Subscribe(ctx context.Context, opts ...mqs.SubscribeOption) (mqs.Subscription, error) {
return q.queue.Subscribe(ctx, opts...)
}
35 changes: 32 additions & 3 deletions internal/mqs/queue.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,43 @@ type InMemoryConfig struct {
type Queue interface {
Init(ctx context.Context) (func(), error)
Publish(ctx context.Context, msg IncomingMessage) error
Subscribe(ctx context.Context) (Subscription, error)
Subscribe(ctx context.Context, opts ...SubscribeOption) (Subscription, error)
}

type Subscription interface {
Receive(ctx context.Context) (*Message, error)
Shutdown(ctx context.Context) error
}

// ConcurrentSubscription indicates a subscription that manages its own concurrency
// internally (e.g. via SDK flow control). When true, the consumer should skip its
// own semaphore-based concurrency limiting.
type ConcurrentSubscription interface {
SupportsConcurrency() bool
}

// SubscribeOption configures subscription behavior.
type SubscribeOption func(*SubscribeOptions)

// SubscribeOptions holds options for Subscribe.
type SubscribeOptions struct {
Concurrency int
}

// WithConcurrency sets the max in-flight messages for the subscription.
func WithConcurrency(n int) SubscribeOption {
return func(o *SubscribeOptions) { o.Concurrency = n }
}

// ApplySubscribeOptions applies all options and returns the result.
func ApplySubscribeOptions(opts []SubscribeOption) SubscribeOptions {
var o SubscribeOptions
for _, opt := range opts {
opt(&o)
}
return o
}

type QueueMessage interface {
Ack()
Nack()
Expand Down Expand Up @@ -81,7 +110,7 @@ func (q *UnimplementedQueue) Publish(ctx context.Context, msg IncomingMessage) e
return errors.New("unimplemented")
}

func (q *UnimplementedQueue) Subscribe(ctx context.Context) (Subscription, error) {
func (q *UnimplementedQueue) Subscribe(ctx context.Context, opts ...SubscribeOption) (Subscription, error) {
return nil, errors.New("unimplemented")
}

Expand All @@ -108,7 +137,7 @@ func (q *InMemoryQueue) Publish(ctx context.Context, incomingMessage IncomingMes
return q.base.Publish(ctx, q.topic, incomingMessage, nil)
}

func (q *InMemoryQueue) Subscribe(ctx context.Context) (Subscription, error) {
func (q *InMemoryQueue) Subscribe(ctx context.Context, opts ...SubscribeOption) (Subscription, error) {
subscription, err := pubsub.OpenSubscription(ctx, q.topicName)
if err != nil {
return nil, err
Expand Down
2 changes: 1 addition & 1 deletion internal/mqs/queue_awssqs.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ func (q *AWSQueue) Publish(ctx context.Context, incomingMessage IncomingMessage)
return q.base.Publish(ctx, q.topic, incomingMessage, nil)
}

func (q *AWSQueue) Subscribe(ctx context.Context) (Subscription, error) {
func (q *AWSQueue) Subscribe(ctx context.Context, opts ...SubscribeOption) (Subscription, error) {
var err error
q.once.Do(func() {
err = q.InitSDK(ctx)
Expand Down
2 changes: 1 addition & 1 deletion internal/mqs/queue_azureservicebus.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ func (q *AzureServiceBusQueue) Publish(ctx context.Context, incomingMessage Inco
return q.base.Publish(ctx, q.topic, incomingMessage, nil)
}

func (q *AzureServiceBusQueue) Subscribe(ctx context.Context) (Subscription, error) {
func (q *AzureServiceBusQueue) Subscribe(ctx context.Context, opts ...SubscribeOption) (Subscription, error) {
var err error
q.once.Do(func() {
err = q.InitClient(ctx)
Expand Down
131 changes: 104 additions & 27 deletions internal/mqs/queue_gcppubsub.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,14 @@ import (
"context"
"fmt"
"sync"
"time"

nativepubsub "cloud.google.com/go/pubsub"
"gocloud.dev/gcp"
"gocloud.dev/pubsub"
"gocloud.dev/pubsub/gcppubsub"
"golang.org/x/oauth2/google"
"google.golang.org/api/option"
"google.golang.org/grpc"
)

Expand Down Expand Up @@ -120,43 +123,117 @@ func (q *GCPPubSubQueue) Publish(ctx context.Context, incomingMessage IncomingMe
return q.base.Publish(ctx, q.topic, incomingMessage, nil)
}

func (q *GCPPubSubQueue) Subscribe(ctx context.Context) (Subscription, error) {
var err error
var subscription *pubsub.Subscription
func (q *GCPPubSubQueue) Subscribe(ctx context.Context, opts ...SubscribeOption) (Subscription, error) {
o := ApplySubscribeOptions(opts)
concurrency := o.Concurrency

var clientOpts []option.ClientOption
if q.config.ServiceAccountCredentials != "" {
subscription, err = q.createSubscriptionWithCredentials(ctx)
} else {
subscription, err = q.createSubscriptionWithoutCredentials(ctx)
creds, err := google.CredentialsFromJSON(ctx, []byte(q.config.ServiceAccountCredentials), "https://www.googleapis.com/auth/pubsub")
if err != nil {
return nil, fmt.Errorf("parse credentials: %w", err)
}
clientOpts = append(clientOpts, option.WithCredentials(creds))
}

client, err := nativepubsub.NewClient(ctx, q.config.ProjectID, clientOpts...)
if err != nil {
return nil, err
return nil, fmt.Errorf("create pubsub client: %w", err)
}

sub := client.Subscription(q.config.SubscriptionID)
sub.ReceiveSettings.MaxOutstandingMessages = concurrency
// Use a single StreamingPull stream per subscription to keep concurrency
// control explicit; scaling is done at the subscription level, not via
// additional goroutines within a subscription.
sub.ReceiveSettings.NumGoroutines = 1
// Disable automatic lease extension so messages are not held beyond the
// subscription's ack deadline. We are intentional about consumer processing
// logic and do not want the SDK silently extending message leases — if a
// handler exceeds the ack deadline, the message should be redelivered.
sub.ReceiveSettings.MaxExtension = -1 * time.Second

msgChan := make(chan *Message, concurrency)
subCtx, cancel := context.WithCancel(ctx)
done := make(chan struct{})

s := &gcpNativeSubscription{
msgChan: msgChan,
cancel: cancel,
done: done,
client: client,
}
return q.base.Subscribe(ctx, subscription)

go func() {
defer close(done)
defer close(msgChan)
// sub.Receive blocks until subCtx is cancelled or a fatal error occurs.
// The callback nacks on context cancellation to avoid buffering messages
// that won't be processed.
s.recvErr = sub.Receive(subCtx, func(_ context.Context, msg *nativepubsub.Message) {
m := &Message{
QueueMessage: &gcpNativeAcker{msg: msg},
LoggableID: msg.ID,
Body: msg.Data,
}
select {
case msgChan <- m:
case <-subCtx.Done():
msg.Nack()
}
})
}()

return s, nil
}

func (q *GCPPubSubQueue) createSubscriptionWithCredentials(ctx context.Context) (*pubsub.Subscription, error) {
conn, err := q.getConn(ctx)
if err != nil {
return nil, err
// gcpNativeSubscription bridges the native SDK StreamingPull to the mqs.Subscription interface.
type gcpNativeSubscription struct {
msgChan <-chan *Message
cancel context.CancelFunc
done chan struct{}
client *nativepubsub.Client
recvErr error // set by the background goroutine when sub.Receive exits
}

var _ Subscription = &gcpNativeSubscription{}
var _ ConcurrentSubscription = &gcpNativeSubscription{}

func (s *gcpNativeSubscription) Receive(ctx context.Context) (*Message, error) {
select {
case msg, ok := <-s.msgChan:
if !ok {
if s.recvErr != nil {
return nil, fmt.Errorf("subscription closed: %w", s.recvErr)
}
return nil, fmt.Errorf("subscription closed")
}
return msg, nil
case <-ctx.Done():
return nil, ctx.Err()
}
}

subClient, err := gcppubsub.SubscriberClient(ctx, conn)
if err != nil {
return nil, err
func (s *gcpNativeSubscription) Shutdown(_ context.Context) error {
s.cancel()
<-s.done
// Nack any remaining buffered messages for faster redelivery.
for msg := range s.msgChan {
msg.Nack()
}
q.cleanupFns = append(q.cleanupFns, func() {
subClient.Close()
})
return s.client.Close()
}

subscription := gcppubsub.OpenSubscription(subClient, gcp.ProjectID(q.config.ProjectID), q.config.SubscriptionID, nil)
return subscription, nil
// SupportsConcurrency returns true — the native SDK manages concurrency via
// MaxOutstandingMessages, so the consumer should skip its own semaphore.
func (s *gcpNativeSubscription) SupportsConcurrency() bool {
return true
}

func (q *GCPPubSubQueue) createSubscriptionWithoutCredentials(ctx context.Context) (*pubsub.Subscription, error) {
subscription, err := pubsub.OpenSubscription(ctx,
fmt.Sprintf("gcppubsub://projects/%s/subscriptions/%s", q.config.ProjectID, q.config.SubscriptionID))
if err != nil {
return nil, err
}
return subscription, nil
// gcpNativeAcker wraps a native SDK message to implement QueueMessage.
type gcpNativeAcker struct {
msg *nativepubsub.Message
}

func (a *gcpNativeAcker) Ack() { a.msg.Ack() }
func (a *gcpNativeAcker) Nack() { a.msg.Nack() }
2 changes: 1 addition & 1 deletion internal/mqs/queue_rabbitmq.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func (q *RabbitMQQueue) Publish(ctx context.Context, incomingMessage IncomingMes
})
}

func (q *RabbitMQQueue) Subscribe(ctx context.Context) (Subscription, error) {
func (q *RabbitMQQueue) Subscribe(ctx context.Context, opts ...SubscribeOption) (Subscription, error) {
var err error
q.once.Do(func() {
err = q.InitConn()
Expand Down
4 changes: 2 additions & 2 deletions internal/publishmq/publishmq.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,6 @@ func New(opts ...func(opts *PublishMQOption)) *PublishMQ {
}
}

func (q *PublishMQ) Subscribe(ctx context.Context) (mqs.Subscription, error) {
return q.queue.Subscribe(ctx)
func (q *PublishMQ) Subscribe(ctx context.Context, opts ...mqs.SubscribeOption) (mqs.Subscription, error) {
return q.queue.Subscribe(ctx, opts...)
}
6 changes: 3 additions & 3 deletions internal/services/consumer_worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import (
// It handles subscription at runtime and consistent error handling for graceful shutdowns.
type ConsumerWorker struct {
name string
subscribe func(ctx context.Context) (mqs.Subscription, error)
subscribe func(ctx context.Context, opts ...mqs.SubscribeOption) (mqs.Subscription, error)
handler consumer.MessageHandler
concurrency int
logger *logging.Logger
Expand All @@ -24,7 +24,7 @@ type ConsumerWorker struct {
// NewConsumerWorker creates a new generic consumer worker.
func NewConsumerWorker(
name string,
subscribe func(ctx context.Context) (mqs.Subscription, error),
subscribe func(ctx context.Context, opts ...mqs.SubscribeOption) (mqs.Subscription, error),
handler consumer.MessageHandler,
concurrency int,
logger *logging.Logger,
Expand All @@ -48,7 +48,7 @@ func (w *ConsumerWorker) Run(ctx context.Context) error {
logger := w.logger.Ctx(ctx)
logger.Info("consumer worker starting", zap.String("name", w.name))

subscription, err := w.subscribe(ctx)
subscription, err := w.subscribe(ctx, mqs.WithConcurrency(w.concurrency))
if err != nil {
logger.Error("error subscribing", zap.String("name", w.name), zap.Error(err))
return err
Expand Down
Loading