From ea9f8ace9fcc3dae2fea338231cfa0e61688401a Mon Sep 17 00:00:00 2001 From: ahrav Date: Thu, 6 Jun 2024 07:58:08 -0700 Subject: [PATCH] [chore] - address comments (#2920) * address comments * fix test * address comments * update comments * fix tests * lint * do the thing --- pkg/channelmetrics/observablechan.go | 63 +++++++++++++---------- pkg/channelmetrics/observablechan_test.go | 18 ++++--- pkg/common/context.go | 19 +++++++ 3 files changed, 66 insertions(+), 34 deletions(-) diff --git a/pkg/channelmetrics/observablechan.go b/pkg/channelmetrics/observablechan.go index f3964bbf7..07197844c 100644 --- a/pkg/channelmetrics/observablechan.go +++ b/pkg/channelmetrics/observablechan.go @@ -23,9 +23,8 @@ type MetricsCollector interface { // It supports any type of channel and records metrics using a provided // MetricsCollector implementation. type ObservableChan[T any] struct { - ch chan T - metrics MetricsCollector - bufferCap int + ch chan T + metrics MetricsCollector } // NewObservableChan creates a new ObservableChan wrapping the provided channel. @@ -35,11 +34,13 @@ type ObservableChan[T any] struct { // the metric names. func NewObservableChan[T any](ch chan T, metrics MetricsCollector) *ObservableChan[T] { oChan := &ObservableChan[T]{ - ch: ch, - metrics: metrics, - bufferCap: cap(ch), + ch: ch, + metrics: metrics, } - oChan.RecordChannelCapacity() // Record capacity immediately + oChan.RecordChannelCapacity() + // Record the current length of the channel. + // Note: The channel is likely empty, but it may contain items if it was pre-existing. + oChan.RecordChannelLen() return oChan } @@ -51,34 +52,42 @@ func (oc *ObservableChan[T]) Close() { // Send sends an item into the channel and records the duration taken to do so. // It also updates the current size of the channel buffer. -func (oc *ObservableChan[T]) Send(ctx context.Context, item T) { - startTime := time.Now() - defer func() { - oc.metrics.RecordProduceDuration(time.Since(startTime)) +// This method blocks until the item is sent. +func (oc *ObservableChan[T]) Send(item T) { _ = oc.SendCtx(context.Background(), item) } + +// SendCtx sends an item into the channel with context and records the duration taken to do so. +// It also updates the current size of the channel buffer and supports context cancellation. +func (oc *ObservableChan[T]) SendCtx(ctx context.Context, item T) error { + defer func(start time.Time) { + oc.metrics.RecordProduceDuration(time.Since(start)) oc.RecordChannelLen() - }() - if err := common.CancellableWrite(ctx, oc.ch, item); err != nil { - ctx.Logger().Error(err, "failed to write item to observable channel") - } + }(time.Now()) + + return common.CancellableWrite(ctx, oc.ch, item) } // Recv receives an item from the channel and records the duration taken to do so. // It also updates the current size of the channel buffer. -func (oc *ObservableChan[T]) Recv(_ context.Context) T { - startTime := time.Now() - defer func() { - oc.metrics.RecordConsumeDuration(time.Since(startTime)) +// This method blocks until an item is available. +func (oc *ObservableChan[T]) Recv() T { + v, _ := oc.RecvCtx(context.Background()) + return v +} + +// RecvCtx receives an item from the channel with context and records the duration taken to do so. +// It also updates the current size of the channel buffer and supports context cancellation. +// If an error occurs, it logs the error. +func (oc *ObservableChan[T]) RecvCtx(ctx context.Context) (T, error) { + defer func(start time.Time) { + oc.metrics.RecordConsumeDuration(time.Since(start)) oc.RecordChannelLen() - }() - return <-oc.ch + }(time.Now()) + + return common.CancellableRecv(ctx, oc.ch) } // RecordChannelCapacity records the capacity of the channel buffer. -func (oc *ObservableChan[T]) RecordChannelCapacity() { - oc.metrics.RecordChannelCap(oc.bufferCap) -} +func (oc *ObservableChan[T]) RecordChannelCapacity() { oc.metrics.RecordChannelCap(cap(oc.ch)) } // RecordChannelLen records the current size of the channel buffer. -func (oc *ObservableChan[T]) RecordChannelLen() { - oc.metrics.RecordChannelLen(len(oc.ch)) -} +func (oc *ObservableChan[T]) RecordChannelLen() { oc.metrics.RecordChannelLen(len(oc.ch)) } diff --git a/pkg/channelmetrics/observablechan_test.go b/pkg/channelmetrics/observablechan_test.go index 920431941..e8fa7aac1 100644 --- a/pkg/channelmetrics/observablechan_test.go +++ b/pkg/channelmetrics/observablechan_test.go @@ -27,14 +27,15 @@ func TestObservableChanSend(t *testing.T) { bufferCap := 10 mockMetrics.On("RecordProduceDuration", mock.Anything).Once() - mockMetrics.On("RecordChannelLen", mock.AnythingOfType("int")).Once() + mockMetrics.On("RecordChannelLen", mock.AnythingOfType("int")).Twice() mockMetrics.On("RecordChannelCap", bufferCap).Once() ch := make(chan int, bufferCap) oc := NewObservableChan(ch, mockMetrics) assert.Equal(t, bufferCap, cap(oc.ch)) - oc.Send(context.Background(), 1) + err := oc.SendCtx(context.Background(), 1) + assert.NoError(t, err) mockMetrics.AssertExpectations(t) } @@ -47,7 +48,7 @@ func TestObservableChanRecv(t *testing.T) { mockMetrics.On("RecordConsumeDuration", mock.Anything).Once() // For the send mockMetrics.On("RecordProduceDuration", mock.Anything).Once() - mockMetrics.On("RecordChannelLen", mock.AnythingOfType("int")).Twice() // For the send and recv + mockMetrics.On("RecordChannelLen", mock.AnythingOfType("int")).Times(3) // For the send and recv mockMetrics.On("RecordChannelCap", bufferCap).Once() ch := make(chan int, bufferCap) @@ -55,12 +56,14 @@ func TestObservableChanRecv(t *testing.T) { assert.Equal(t, bufferCap, cap(oc.ch)) go func() { - oc.Send(context.Background(), 1) + err := oc.SendCtx(context.Background(), 1) + assert.NoError(t, err) }() time.Sleep(100 * time.Millisecond) // Ensure Send happens before Recv - oc.Recv(context.Background()) + _, err := oc.RecvCtx(context.Background()) + assert.NoError(t, err) mockMetrics.AssertExpectations(t) } @@ -72,6 +75,7 @@ func TestObservableChanRecordChannelCapacity(t *testing.T) { bufferCap := 10 mockMetrics.On("RecordChannelCap", bufferCap).Twice() + mockMetrics.On("RecordChannelLen", mock.AnythingOfType("int")).Once() ch := make(chan int, bufferCap) oc := NewObservableChan(ch, mockMetrics) @@ -87,7 +91,7 @@ func TestObservableChanRecordChannelLen(t *testing.T) { mockMetrics := new(MockMetricsCollector) bufferCap := 10 - mockMetrics.On("RecordChannelLen", mock.AnythingOfType("int")).Once() + mockMetrics.On("RecordChannelLen", mock.AnythingOfType("int")).Twice() mockMetrics.On("RecordChannelCap", bufferCap).Once() ch := make(chan int, bufferCap) @@ -105,7 +109,7 @@ func TestObservableChan_Close(t *testing.T) { bufferCap := 1 mockMetrics.On("RecordChannelCap", bufferCap).Once() - mockMetrics.On("RecordChannelLen", mock.AnythingOfType("int")).Once() + mockMetrics.On("RecordChannelLen", mock.AnythingOfType("int")).Twice() ch := make(chan int, bufferCap) oc := NewObservableChan(ch, mockMetrics) diff --git a/pkg/common/context.go b/pkg/common/context.go index 775494404..7c8a0aa20 100644 --- a/pkg/common/context.go +++ b/pkg/common/context.go @@ -27,3 +27,22 @@ func CancellableWrite[T any](ctx context.Context, ch chan<- T, item T) error { } } } + +// CancellableRecv blocks on receiving an item from the channel but can be +// cancelled by the context. If both the context is cancelled and the channel +// read would succeed, either operation will be performed randomly. +func CancellableRecv[T any](ctx context.Context, ch <-chan T) (T, error) { + var zero T // zero value of type T + + select { + case <-ctx.Done(): // priority to context cancellation + return zero, ctx.Err() + default: + select { + case <-ctx.Done(): + return zero, ctx.Err() + case item := <-ch: + return item, nil + } + } +}