diff --git a/pkg/sources/source_manager.go b/pkg/sources/source_manager.go index 923f04648..c1c9c3a34 100644 --- a/pkg/sources/source_manager.go +++ b/pkg/sources/source_manager.go @@ -19,7 +19,7 @@ type handle int64 // SourceInitFunc is a function that takes a source and job ID and returns an // initialized Source. -type SourceInitFunc func(ctx context.Context, sourceID int64, jobID int64) (Source, error) +type SourceInitFunc func(ctx context.Context, jobID, sourceID int64) (Source, error) // sourceInfo is an aggregate struct to store source information provided on // initialization. @@ -228,7 +228,7 @@ func (s *SourceManager) run(ctx context.Context, handle handle, jobID int64, rep report.ReportError(Fatal{err}) return Fatal{err} } - source, err := sourceInfo.initFunc(ctx, int64(handle), jobID) + source, err := sourceInfo.initFunc(ctx, jobID, int64(handle)) if err != nil { report.ReportError(Fatal{err}) return Fatal{err} diff --git a/pkg/sources/source_manager_test.go b/pkg/sources/source_manager_test.go index 266263343..4d66dbc7c 100644 --- a/pkg/sources/source_manager_test.go +++ b/pkg/sources/source_manager_test.go @@ -245,3 +245,42 @@ func TestSourceManagerContextCancelled(t *testing.T) { report := ref.Snapshot() assert.Error(t, report.FatalError()) } + +type DummyAPI struct { + registerSource func(context.Context, string, sourcespb.SourceType) (int64, error) + getJobID func(context.Context, int64) (int64, error) +} + +func (api DummyAPI) RegisterSource(ctx context.Context, name string, kind sourcespb.SourceType) (int64, error) { + return api.registerSource(ctx, name, kind) +} + +func (api DummyAPI) GetJobID(ctx context.Context, id int64) (int64, error) { + return api.getJobID(ctx, id) +} + +func TestSourceManagerJobAndSourceIDs(t *testing.T) { + mgr := NewManager(WithAPI(DummyAPI{ + registerSource: func(context.Context, string, sourcespb.SourceType) (int64, error) { + return 1337, nil + }, + getJobID: func(context.Context, int64) (int64, error) { + return 9001, nil + }, + })) + var ( + initializedJobID int64 + initializedSourceID int64 + ) + handle, err := mgr.Enroll(context.Background(), "dummy", 1337, + func(ctx context.Context, jobID, sourceID int64) (Source, error) { + initializedJobID = jobID + initializedSourceID = sourceID + return nil, fmt.Errorf("ignore") + }) + assert.NoError(t, err) + + _, _ = mgr.Run(context.Background(), handle) + assert.Equal(t, int64(1337), initializedSourceID) + assert.Equal(t, int64(9001), initializedJobID) +}