trufflehog/pkg/channelmetrics/observablechan_test.go

139 lines
3.4 KiB
Go

package channelmetrics
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/trufflesecurity/trufflehog/v3/pkg/context"
)
type MockMetricsCollector struct{ mock.Mock }
func (m *MockMetricsCollector) RecordProduceDuration(duration time.Duration) { m.Called(duration) }
func (m *MockMetricsCollector) RecordConsumeDuration(duration time.Duration) { m.Called(duration) }
func (m *MockMetricsCollector) RecordChannelLen(size int) { m.Called(size) }
func (m *MockMetricsCollector) RecordChannelCap(capacity int) { m.Called(capacity) }
func TestObservableChanSend(t *testing.T) {
t.Parallel()
mockMetrics := new(MockMetricsCollector)
bufferCap := 10
mockMetrics.On("RecordProduceDuration", mock.Anything).Once()
mockMetrics.On("RecordChannelLen", mock.AnythingOfType("int")).Twice()
mockMetrics.On("RecordChannelCap", bufferCap).Once()
ch := make(chan int, bufferCap)
oc := NewObservableChan(ch, mockMetrics)
assert.Equal(t, bufferCap, cap(oc.ch))
err := oc.SendCtx(context.Background(), 1)
assert.NoError(t, err)
mockMetrics.AssertExpectations(t)
}
func TestObservableChanRecv(t *testing.T) {
t.Parallel()
mockMetrics := new(MockMetricsCollector)
bufferCap := 10
mockMetrics.On("RecordConsumeDuration", mock.Anything).Once() // For the send
mockMetrics.On("RecordProduceDuration", mock.Anything).Once()
mockMetrics.On("RecordChannelLen", mock.AnythingOfType("int")).Times(3) // For the send and recv
mockMetrics.On("RecordChannelCap", bufferCap).Once()
ch := make(chan int, bufferCap)
oc := NewObservableChan(ch, mockMetrics)
assert.Equal(t, bufferCap, cap(oc.ch))
go func() {
err := oc.SendCtx(context.Background(), 1)
assert.NoError(t, err)
}()
time.Sleep(100 * time.Millisecond) // Ensure Send happens before Recv
_, err := oc.RecvCtx(context.Background())
assert.NoError(t, err)
mockMetrics.AssertExpectations(t)
}
func TestObservableChanRecordChannelCapacity(t *testing.T) {
t.Parallel()
mockMetrics := new(MockMetricsCollector)
bufferCap := 10
mockMetrics.On("RecordChannelCap", bufferCap).Twice()
mockMetrics.On("RecordChannelLen", mock.AnythingOfType("int")).Once()
ch := make(chan int, bufferCap)
oc := NewObservableChan(ch, mockMetrics)
oc.RecordChannelCapacity()
mockMetrics.AssertExpectations(t)
}
func TestObservableChanRecordChannelLen(t *testing.T) {
t.Parallel()
mockMetrics := new(MockMetricsCollector)
bufferCap := 10
mockMetrics.On("RecordChannelLen", mock.AnythingOfType("int")).Twice()
mockMetrics.On("RecordChannelCap", bufferCap).Once()
ch := make(chan int, bufferCap)
oc := NewObservableChan(ch, mockMetrics)
oc.RecordChannelLen()
mockMetrics.AssertExpectations(t)
}
func TestObservableChan_Close(t *testing.T) {
t.Parallel()
mockMetrics := new(MockMetricsCollector)
bufferCap := 1
mockMetrics.On("RecordChannelCap", bufferCap).Once()
mockMetrics.On("RecordChannelLen", mock.AnythingOfType("int")).Twice()
ch := make(chan int, bufferCap)
oc := NewObservableChan(ch, mockMetrics)
oc.Close()
mockMetrics.AssertExpectations(t)
}
func TestObservableChanClosed(t *testing.T) {
t.Parallel()
ch := make(chan int)
close(ch)
oc := NewObservableChan(ch, nil)
ctx, cancel := context.WithCancel(context.Background())
// Closed channel should return with an error.
v, err := oc.RecvCtx(ctx)
assert.Error(t, err)
assert.Equal(t, 0, v)
// Cancelled context should also return with an error.
cancel()
_, err = oc.RecvCtx(ctx)
assert.Error(t, err)
}