diff --git a/pkg/cache/cache.go b/pkg/cache/cache.go index da7156ac7..9247dfc83 100644 --- a/pkg/cache/cache.go +++ b/pkg/cache/cache.go @@ -2,11 +2,11 @@ package cache // Cache is used to store key/value pairs. -type Cache interface { +type Cache[T any] interface { // Set stores the given key/value pair. - Set(string, any) + Set(string, T) // Get returns the value for the given key and a boolean indicating if the key was found. - Get(string) (any, bool) + Get(string) (T, bool) // Exists returns true if the given key exists in the cache. Exists(string) bool // Delete the given key from the cache. @@ -18,7 +18,7 @@ type Cache interface { // Keys returns all keys in the cache. Keys() []string // Values returns all values in the cache. - Values() []any + Values() []T // Contents returns all keys in the cache encoded as a string. Contents() string } diff --git a/pkg/cache/memory/memory.go b/pkg/cache/memory/memory.go index ce5bd4365..f38b67fdf 100644 --- a/pkg/cache/memory/memory.go +++ b/pkg/cache/memory/memory.go @@ -14,46 +14,46 @@ const ( ) // Cache wraps the go-cache library to provide an in-memory key-value store. -type Cache struct { +type Cache[T any] struct { c *cache.Cache expiration time.Duration purgeInterval time.Duration } // CacheOption defines a function type used for configuring a Cache. -type CacheOption func(*Cache) +type CacheOption[T any] func(*Cache[T]) // WithExpirationInterval returns a CacheOption to set the expiration interval of cache items. // The interval determines the duration a cached item remains in the cache before it is expired. -func WithExpirationInterval(interval time.Duration) CacheOption { - return func(c *Cache) { c.expiration = interval } +func WithExpirationInterval[T any](interval time.Duration) CacheOption[T] { + return func(c *Cache[T]) { c.expiration = interval } } // WithPurgeInterval returns a CacheOption to set the interval at which the cache purges expired items. // Regular purging helps in freeing up memory by removing stale entries. -func WithPurgeInterval(interval time.Duration) CacheOption { - return func(c *Cache) { c.purgeInterval = interval } +func WithPurgeInterval[T any](interval time.Duration) CacheOption[T] { + return func(c *Cache[T]) { c.purgeInterval = interval } } // New constructs a new in-memory cache instance with optional configurations. // By default, it sets the expiration and purge intervals to 12 and 13 hours, respectively. // These defaults can be overridden using the functional options: WithExpirationInterval and WithPurgeInterval. -func New(opts ...CacheOption) *Cache { - return NewWithData(nil, opts...) +func New[T any](opts ...CacheOption[T]) *Cache[T] { + return NewWithData[T](nil, opts...) } // CacheEntry represents a single entry in the cache, consisting of a key and its corresponding value. -type CacheEntry struct { +type CacheEntry[T any] struct { // Key is the unique identifier for the entry. Key string // Value is the data stored in the entry. - Value any + Value T } // NewWithData constructs a new in-memory cache with existing data. // It also accepts CacheOption parameters to override default configuration values. -func NewWithData(data []CacheEntry, opts ...CacheOption) *Cache { - instance := &Cache{expiration: defaultExpirationInterval, purgeInterval: defaultPurgeInterval} +func NewWithData[T any](data []CacheEntry[T], opts ...CacheOption[T]) *Cache[T] { + instance := &Cache[T]{expiration: defaultExpirationInterval, purgeInterval: defaultPurgeInterval} for _, opt := range opts { opt(instance) } @@ -69,38 +69,46 @@ func NewWithData(data []CacheEntry, opts ...CacheOption) *Cache { } // Set adds a key-value pair to the cache. -func (c *Cache) Set(key string, value any) { +func (c *Cache[T]) Set(key string, value T) { c.c.Set(key, value, defaultExpiration) } // Get returns the value for the given key. -func (c *Cache) Get(key string) (any, bool) { - return c.c.Get(key) +func (c *Cache[T]) Get(key string) (T, bool) { + var value T + + v, ok := c.c.Get(key) + if !ok { + return value, false + } + + value, ok = v.(T) + return value, ok } // Exists returns true if the given key exists in the cache. -func (c *Cache) Exists(key string) bool { +func (c *Cache[T]) Exists(key string) bool { _, ok := c.c.Get(key) return ok } // Delete removes the key-value pair from the cache. -func (c *Cache) Delete(key string) { +func (c *Cache[T]) Delete(key string) { c.c.Delete(key) } // Clear removes all key-value pairs from the cache. -func (c *Cache) Clear() { +func (c *Cache[T]) Clear() { c.c.Flush() } // Count returns the number of key-value pairs in the cache. -func (c *Cache) Count() int { +func (c *Cache[T]) Count() int { return c.c.ItemCount() } // Keys returns all keys in the cache. -func (c *Cache) Keys() []string { +func (c *Cache[T]) Keys() []string { items := c.c.Items() res := make([]string, 0, len(items)) for k := range items { @@ -110,17 +118,20 @@ func (c *Cache) Keys() []string { } // Values returns all values in the cache. -func (c *Cache) Values() []any { +func (c *Cache[T]) Values() []T { items := c.c.Items() - res := make([]any, 0, len(items)) + res := make([]T, 0, len(items)) for _, v := range items { - res = append(res, v.Object) + obj, ok := v.Object.(T) + if ok { + res = append(res, obj) + } } return res } // Contents returns a comma-separated string containing all keys in the cache. -func (c *Cache) Contents() string { +func (c *Cache[T]) Contents() string { items := c.c.Items() res := make([]string, 0, len(items)) for k := range items { diff --git a/pkg/cache/memory/memory_test.go b/pkg/cache/memory/memory_test.go index 0ec9c926b..c4e167a57 100644 --- a/pkg/cache/memory/memory_test.go +++ b/pkg/cache/memory/memory_test.go @@ -10,7 +10,7 @@ import ( ) func TestCache(t *testing.T) { - c := New() + c := New[string]() // Test set and get. c.Set("key1", "key1") @@ -32,7 +32,7 @@ func TestCache(t *testing.T) { // Test delete. c.Delete("key1") v, ok = c.Get("key1") - if ok || v != nil { + if ok || v != "" { t.Fatalf("Unexpected value for key1 after delete: %v, %v", v, ok) } @@ -40,7 +40,7 @@ func TestCache(t *testing.T) { c.Set("key10", "key10") c.Clear() v, ok = c.Get("key10") - if ok || v != nil { + if ok || v != "" { t.Fatalf("Unexpected value for key10 after clear: %v, %v", v, ok) } @@ -59,9 +59,7 @@ func TestCache(t *testing.T) { // Test getting only the values. vals := make([]string, 0, c.Count()) - for _, v := range c.Values() { - vals = append(vals, v.(string)) - } + vals = append(vals, c.Values()...) sort.Strings(vals) sort.Strings(values) if !cmp.Equal(values, vals) { @@ -83,7 +81,7 @@ func TestCache(t *testing.T) { } func TestCache_NewWithData(t *testing.T) { - data := []CacheEntry{{"key1", "value1"}, {"key2", "value2"}, {"key3", "value3"}} + data := []CacheEntry[string]{{"key1", "value1"}, {"key2", "value2"}, {"key3", "value3"}} c := NewWithData(data) // Test the count. @@ -106,10 +104,10 @@ func TestCache_NewWithData(t *testing.T) { } } -func setupBenchmarks(b *testing.B) *Cache { +func setupBenchmarks(b *testing.B) *Cache[string] { b.Helper() - c := New() + c := New[string]() for i := 0; i < 500_000; i++ { key := fmt.Sprintf("key%d", i) @@ -120,7 +118,7 @@ func setupBenchmarks(b *testing.B) *Cache { } func BenchmarkSet(b *testing.B) { - c := New() + c := New[string]() for i := 0; i < b.N; i++ { key := fmt.Sprintf("key%d", i) diff --git a/pkg/sources/gcs/gcs.go b/pkg/sources/gcs/gcs.go index 52ecbf7e6..3aafa6dc6 100644 --- a/pkg/sources/gcs/gcs.go +++ b/pkg/sources/gcs/gcs.go @@ -81,11 +81,11 @@ type Source struct { // at given increments. type persistableCache struct { persistIncrement int - cache.Cache + cache.Cache[string] *sources.Progress } -func newPersistableCache(increment int, cache cache.Cache, p *sources.Progress) *persistableCache { +func newPersistableCache(increment int, cache cache.Cache[string], p *sources.Progress) *persistableCache { return &persistableCache{ persistIncrement: increment, Cache: cache, @@ -95,7 +95,7 @@ func newPersistableCache(increment int, cache cache.Cache, p *sources.Progress) // Set overrides the cache Set method of the cache to enable the persistence // of the cache contents the Progress of the source at given increments. -func (c *persistableCache) Set(key string, val any) { +func (c *persistableCache) Set(key string, val string) { c.Cache.Set(key, val) if ok, contents := c.shouldPersist(); ok { c.Progress.EncodedResumeInfo = contents @@ -293,18 +293,18 @@ func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk, _ . } func (s *Source) setupCache(ctx context.Context) *persistableCache { - var c cache.Cache + var c cache.Cache[string] if s.Progress.EncodedResumeInfo != "" { keys := strings.Split(s.Progress.EncodedResumeInfo, ",") - entries := make([]memory.CacheEntry, len(keys)) + entries := make([]memory.CacheEntry[string], len(keys)) for i, val := range keys { - entries[i] = memory.CacheEntry{Key: val, Value: val} + entries[i] = memory.CacheEntry[string]{Key: val, Value: val} } - c = memory.NewWithData(entries) + c = memory.NewWithData[string](entries) ctx.Logger().V(3).Info("Loaded cache", "num_entries", len(entries)) } else { - c = memory.New() + c = memory.New[string]() } // TODO (ahrav): Make this configurable via conn. @@ -312,7 +312,7 @@ func (s *Source) setupCache(ctx context.Context) *persistableCache { return persistCache } -func (s *Source) setProgress(ctx context.Context, md5, objName string, cache cache.Cache) { +func (s *Source) setProgress(ctx context.Context, md5, objName string, cache cache.Cache[string]) { s.mu.Lock() defer s.mu.Unlock() diff --git a/pkg/sources/github/github.go b/pkg/sources/github/github.go index 7b24d6ee5..acffdec1b 100644 --- a/pkg/sources/github/github.go +++ b/pkg/sources/github/github.go @@ -57,7 +57,7 @@ type Source struct { sourceID sources.SourceID jobID sources.JobID verify bool - orgsCache cache.Cache + orgsCache cache.Cache[string] memberCache map[string]struct{} repos []string filteredRepoCache *filteredRepoCache @@ -123,11 +123,11 @@ func (s *Source) JobID() sources.JobID { // filteredRepoCache is a wrapper around cache.Cache that filters out repos // based on include and exclude globs. type filteredRepoCache struct { - cache.Cache + cache.Cache[string] include, exclude []glob.Glob } -func (s *Source) newFilteredRepoCache(c cache.Cache, include, exclude []string) *filteredRepoCache { +func (s *Source) newFilteredRepoCache(c cache.Cache[string], include, exclude []string) *filteredRepoCache { includeGlobs := make([]glob.Glob, 0, len(include)) excludeGlobs := make([]glob.Glob, 0, len(exclude)) for _, ig := range include { @@ -209,13 +209,13 @@ func (s *Source) Init(aCtx context.Context, name string, jobID sources.JobID, so } s.conn = &conn - s.orgsCache = memory.New() + s.orgsCache = memory.New[string]() for _, org := range s.conn.Organizations { s.orgsCache.Set(org, org) } s.memberCache = make(map[string]struct{}) - s.filteredRepoCache = s.newFilteredRepoCache(memory.New(), + s.filteredRepoCache = s.newFilteredRepoCache(memory.New[string](), append(s.conn.GetRepositories(), s.conn.GetIncludeRepos()...), s.conn.GetIgnoreRepos(), ) @@ -409,19 +409,13 @@ RepoLoop: for _, repo := range s.filteredRepoCache.Values() { repoCtx := context.WithValue(ctx, "repo", repo) - r, ok := repo.(string) - if !ok { - repoCtx.Logger().Error(fmt.Errorf("type assertion failed"), "Unexpected value in cache") - continue - } - // Ensure that |s.repoInfoCache| contains an entry for |repo|. // This compensates for differences in enumeration logic between `--org` and `--repo`. // See: https://github.com/trufflesecurity/trufflehog/pull/2379#discussion_r1487454788 - if _, ok := s.repoInfoCache.get(r); !ok { + if _, ok := s.repoInfoCache.get(repo); !ok { repoCtx.Logger().V(2).Info("Caching repository info") - _, urlParts, err := getRepoURLParts(r) + _, urlParts, err := getRepoURLParts(repo) if err != nil { repoCtx.Logger().Error(err, "Failed to parse repository URL") continue @@ -434,7 +428,7 @@ RepoLoop: gist, _, err := s.apiClient.Gists.Get(repoCtx, gistID) // Normalize the URL to the Gist's pull URL. // See https://github.com/trufflesecurity/trufflehog/pull/2625#issuecomment-2025507937 - r = gist.GetGitPullURL() + repo = gist.GetGitPullURL() if s.handleRateLimit(err) { continue } @@ -461,7 +455,7 @@ RepoLoop: } } } - s.repos = append(s.repos, r) + s.repos = append(s.repos, repo) } githubReposEnumerated.WithLabelValues(s.name).Set(float64(len(s.repos))) s.log.Info("Completed enumeration", "num_repos", len(s.repos), "num_orgs", s.orgsCache.Count(), "num_members", len(s.memberCache)) diff --git a/pkg/sources/github/github_test.go b/pkg/sources/github/github_test.go index 70575be13..2e777eb08 100644 --- a/pkg/sources/github/github_test.go +++ b/pkg/sources/github/github_test.go @@ -380,7 +380,7 @@ func TestEnumerateUnauthenticated(t *testing.T) { JSON([]map[string]string{{"full_name": "super-secret-org/super-secret-repo", "clone_url": "https://github.com/super-secret-org/super-secret-repo.git"}}) s := initTestSource(nil) - s.orgsCache = memory.New() + s.orgsCache = memory.New[string]() s.orgsCache.Set("super-secret-org", "super-secret-org") s.enumerateUnauthenticated(context.Background(), apiEndpoint) assert.Equal(t, 1, s.filteredRepoCache.Count())