refactor(cache): use generics (#2930)

This commit is contained in:
Richard Gomez 2024-06-06 13:08:00 -04:00 committed by GitHub
parent ea9f8ace9f
commit 5216142960
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 66 additions and 63 deletions

8
pkg/cache/cache.go vendored
View file

@ -2,11 +2,11 @@
package cache package cache
// Cache is used to store key/value pairs. // Cache is used to store key/value pairs.
type Cache interface { type Cache[T any] interface {
// Set stores the given key/value pair. // Set stores the given key/value pair.
Set(string, any) Set(string, T)
// Get returns the value for the given key and a boolean indicating if the key was found. // Get returns the value for the given key and a boolean indicating if the key was found.
Get(string) (any, bool) Get(string) (T, bool)
// Exists returns true if the given key exists in the cache. // Exists returns true if the given key exists in the cache.
Exists(string) bool Exists(string) bool
// Delete the given key from the cache. // Delete the given key from the cache.
@ -18,7 +18,7 @@ type Cache interface {
// Keys returns all keys in the cache. // Keys returns all keys in the cache.
Keys() []string Keys() []string
// Values returns all values in the cache. // Values returns all values in the cache.
Values() []any Values() []T
// Contents returns all keys in the cache encoded as a string. // Contents returns all keys in the cache encoded as a string.
Contents() string Contents() string
} }

View file

@ -14,46 +14,46 @@ const (
) )
// Cache wraps the go-cache library to provide an in-memory key-value store. // Cache wraps the go-cache library to provide an in-memory key-value store.
type Cache struct { type Cache[T any] struct {
c *cache.Cache c *cache.Cache
expiration time.Duration expiration time.Duration
purgeInterval time.Duration purgeInterval time.Duration
} }
// CacheOption defines a function type used for configuring a Cache. // CacheOption defines a function type used for configuring a Cache.
type CacheOption func(*Cache) type CacheOption[T any] func(*Cache[T])
// WithExpirationInterval returns a CacheOption to set the expiration interval of cache items. // WithExpirationInterval returns a CacheOption to set the expiration interval of cache items.
// The interval determines the duration a cached item remains in the cache before it is expired. // The interval determines the duration a cached item remains in the cache before it is expired.
func WithExpirationInterval(interval time.Duration) CacheOption { func WithExpirationInterval[T any](interval time.Duration) CacheOption[T] {
return func(c *Cache) { c.expiration = interval } return func(c *Cache[T]) { c.expiration = interval }
} }
// WithPurgeInterval returns a CacheOption to set the interval at which the cache purges expired items. // WithPurgeInterval returns a CacheOption to set the interval at which the cache purges expired items.
// Regular purging helps in freeing up memory by removing stale entries. // Regular purging helps in freeing up memory by removing stale entries.
func WithPurgeInterval(interval time.Duration) CacheOption { func WithPurgeInterval[T any](interval time.Duration) CacheOption[T] {
return func(c *Cache) { c.purgeInterval = interval } return func(c *Cache[T]) { c.purgeInterval = interval }
} }
// New constructs a new in-memory cache instance with optional configurations. // New constructs a new in-memory cache instance with optional configurations.
// By default, it sets the expiration and purge intervals to 12 and 13 hours, respectively. // By default, it sets the expiration and purge intervals to 12 and 13 hours, respectively.
// These defaults can be overridden using the functional options: WithExpirationInterval and WithPurgeInterval. // These defaults can be overridden using the functional options: WithExpirationInterval and WithPurgeInterval.
func New(opts ...CacheOption) *Cache { func New[T any](opts ...CacheOption[T]) *Cache[T] {
return NewWithData(nil, opts...) return NewWithData[T](nil, opts...)
} }
// CacheEntry represents a single entry in the cache, consisting of a key and its corresponding value. // CacheEntry represents a single entry in the cache, consisting of a key and its corresponding value.
type CacheEntry struct { type CacheEntry[T any] struct {
// Key is the unique identifier for the entry. // Key is the unique identifier for the entry.
Key string Key string
// Value is the data stored in the entry. // Value is the data stored in the entry.
Value any Value T
} }
// NewWithData constructs a new in-memory cache with existing data. // NewWithData constructs a new in-memory cache with existing data.
// It also accepts CacheOption parameters to override default configuration values. // It also accepts CacheOption parameters to override default configuration values.
func NewWithData(data []CacheEntry, opts ...CacheOption) *Cache { func NewWithData[T any](data []CacheEntry[T], opts ...CacheOption[T]) *Cache[T] {
instance := &Cache{expiration: defaultExpirationInterval, purgeInterval: defaultPurgeInterval} instance := &Cache[T]{expiration: defaultExpirationInterval, purgeInterval: defaultPurgeInterval}
for _, opt := range opts { for _, opt := range opts {
opt(instance) opt(instance)
} }
@ -69,38 +69,46 @@ func NewWithData(data []CacheEntry, opts ...CacheOption) *Cache {
} }
// Set adds a key-value pair to the cache. // Set adds a key-value pair to the cache.
func (c *Cache) Set(key string, value any) { func (c *Cache[T]) Set(key string, value T) {
c.c.Set(key, value, defaultExpiration) c.c.Set(key, value, defaultExpiration)
} }
// Get returns the value for the given key. // Get returns the value for the given key.
func (c *Cache) Get(key string) (any, bool) { func (c *Cache[T]) Get(key string) (T, bool) {
return c.c.Get(key) var value T
v, ok := c.c.Get(key)
if !ok {
return value, false
}
value, ok = v.(T)
return value, ok
} }
// Exists returns true if the given key exists in the cache. // Exists returns true if the given key exists in the cache.
func (c *Cache) Exists(key string) bool { func (c *Cache[T]) Exists(key string) bool {
_, ok := c.c.Get(key) _, ok := c.c.Get(key)
return ok return ok
} }
// Delete removes the key-value pair from the cache. // Delete removes the key-value pair from the cache.
func (c *Cache) Delete(key string) { func (c *Cache[T]) Delete(key string) {
c.c.Delete(key) c.c.Delete(key)
} }
// Clear removes all key-value pairs from the cache. // Clear removes all key-value pairs from the cache.
func (c *Cache) Clear() { func (c *Cache[T]) Clear() {
c.c.Flush() c.c.Flush()
} }
// Count returns the number of key-value pairs in the cache. // Count returns the number of key-value pairs in the cache.
func (c *Cache) Count() int { func (c *Cache[T]) Count() int {
return c.c.ItemCount() return c.c.ItemCount()
} }
// Keys returns all keys in the cache. // Keys returns all keys in the cache.
func (c *Cache) Keys() []string { func (c *Cache[T]) Keys() []string {
items := c.c.Items() items := c.c.Items()
res := make([]string, 0, len(items)) res := make([]string, 0, len(items))
for k := range items { for k := range items {
@ -110,17 +118,20 @@ func (c *Cache) Keys() []string {
} }
// Values returns all values in the cache. // Values returns all values in the cache.
func (c *Cache) Values() []any { func (c *Cache[T]) Values() []T {
items := c.c.Items() items := c.c.Items()
res := make([]any, 0, len(items)) res := make([]T, 0, len(items))
for _, v := range items { for _, v := range items {
res = append(res, v.Object) obj, ok := v.Object.(T)
if ok {
res = append(res, obj)
}
} }
return res return res
} }
// Contents returns a comma-separated string containing all keys in the cache. // Contents returns a comma-separated string containing all keys in the cache.
func (c *Cache) Contents() string { func (c *Cache[T]) Contents() string {
items := c.c.Items() items := c.c.Items()
res := make([]string, 0, len(items)) res := make([]string, 0, len(items))
for k := range items { for k := range items {

View file

@ -10,7 +10,7 @@ import (
) )
func TestCache(t *testing.T) { func TestCache(t *testing.T) {
c := New() c := New[string]()
// Test set and get. // Test set and get.
c.Set("key1", "key1") c.Set("key1", "key1")
@ -32,7 +32,7 @@ func TestCache(t *testing.T) {
// Test delete. // Test delete.
c.Delete("key1") c.Delete("key1")
v, ok = c.Get("key1") v, ok = c.Get("key1")
if ok || v != nil { if ok || v != "" {
t.Fatalf("Unexpected value for key1 after delete: %v, %v", v, ok) t.Fatalf("Unexpected value for key1 after delete: %v, %v", v, ok)
} }
@ -40,7 +40,7 @@ func TestCache(t *testing.T) {
c.Set("key10", "key10") c.Set("key10", "key10")
c.Clear() c.Clear()
v, ok = c.Get("key10") v, ok = c.Get("key10")
if ok || v != nil { if ok || v != "" {
t.Fatalf("Unexpected value for key10 after clear: %v, %v", v, ok) t.Fatalf("Unexpected value for key10 after clear: %v, %v", v, ok)
} }
@ -59,9 +59,7 @@ func TestCache(t *testing.T) {
// Test getting only the values. // Test getting only the values.
vals := make([]string, 0, c.Count()) vals := make([]string, 0, c.Count())
for _, v := range c.Values() { vals = append(vals, c.Values()...)
vals = append(vals, v.(string))
}
sort.Strings(vals) sort.Strings(vals)
sort.Strings(values) sort.Strings(values)
if !cmp.Equal(values, vals) { if !cmp.Equal(values, vals) {
@ -83,7 +81,7 @@ func TestCache(t *testing.T) {
} }
func TestCache_NewWithData(t *testing.T) { func TestCache_NewWithData(t *testing.T) {
data := []CacheEntry{{"key1", "value1"}, {"key2", "value2"}, {"key3", "value3"}} data := []CacheEntry[string]{{"key1", "value1"}, {"key2", "value2"}, {"key3", "value3"}}
c := NewWithData(data) c := NewWithData(data)
// Test the count. // Test the count.
@ -106,10 +104,10 @@ func TestCache_NewWithData(t *testing.T) {
} }
} }
func setupBenchmarks(b *testing.B) *Cache { func setupBenchmarks(b *testing.B) *Cache[string] {
b.Helper() b.Helper()
c := New() c := New[string]()
for i := 0; i < 500_000; i++ { for i := 0; i < 500_000; i++ {
key := fmt.Sprintf("key%d", i) key := fmt.Sprintf("key%d", i)
@ -120,7 +118,7 @@ func setupBenchmarks(b *testing.B) *Cache {
} }
func BenchmarkSet(b *testing.B) { func BenchmarkSet(b *testing.B) {
c := New() c := New[string]()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
key := fmt.Sprintf("key%d", i) key := fmt.Sprintf("key%d", i)

View file

@ -81,11 +81,11 @@ type Source struct {
// at given increments. // at given increments.
type persistableCache struct { type persistableCache struct {
persistIncrement int persistIncrement int
cache.Cache cache.Cache[string]
*sources.Progress *sources.Progress
} }
func newPersistableCache(increment int, cache cache.Cache, p *sources.Progress) *persistableCache { func newPersistableCache(increment int, cache cache.Cache[string], p *sources.Progress) *persistableCache {
return &persistableCache{ return &persistableCache{
persistIncrement: increment, persistIncrement: increment,
Cache: cache, Cache: cache,
@ -95,7 +95,7 @@ func newPersistableCache(increment int, cache cache.Cache, p *sources.Progress)
// Set overrides the cache Set method of the cache to enable the persistence // Set overrides the cache Set method of the cache to enable the persistence
// of the cache contents the Progress of the source at given increments. // of the cache contents the Progress of the source at given increments.
func (c *persistableCache) Set(key string, val any) { func (c *persistableCache) Set(key string, val string) {
c.Cache.Set(key, val) c.Cache.Set(key, val)
if ok, contents := c.shouldPersist(); ok { if ok, contents := c.shouldPersist(); ok {
c.Progress.EncodedResumeInfo = contents c.Progress.EncodedResumeInfo = contents
@ -293,18 +293,18 @@ func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk, _ .
} }
func (s *Source) setupCache(ctx context.Context) *persistableCache { func (s *Source) setupCache(ctx context.Context) *persistableCache {
var c cache.Cache var c cache.Cache[string]
if s.Progress.EncodedResumeInfo != "" { if s.Progress.EncodedResumeInfo != "" {
keys := strings.Split(s.Progress.EncodedResumeInfo, ",") keys := strings.Split(s.Progress.EncodedResumeInfo, ",")
entries := make([]memory.CacheEntry, len(keys)) entries := make([]memory.CacheEntry[string], len(keys))
for i, val := range keys { for i, val := range keys {
entries[i] = memory.CacheEntry{Key: val, Value: val} entries[i] = memory.CacheEntry[string]{Key: val, Value: val}
} }
c = memory.NewWithData(entries) c = memory.NewWithData[string](entries)
ctx.Logger().V(3).Info("Loaded cache", "num_entries", len(entries)) ctx.Logger().V(3).Info("Loaded cache", "num_entries", len(entries))
} else { } else {
c = memory.New() c = memory.New[string]()
} }
// TODO (ahrav): Make this configurable via conn. // TODO (ahrav): Make this configurable via conn.
@ -312,7 +312,7 @@ func (s *Source) setupCache(ctx context.Context) *persistableCache {
return persistCache return persistCache
} }
func (s *Source) setProgress(ctx context.Context, md5, objName string, cache cache.Cache) { func (s *Source) setProgress(ctx context.Context, md5, objName string, cache cache.Cache[string]) {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()

View file

@ -57,7 +57,7 @@ type Source struct {
sourceID sources.SourceID sourceID sources.SourceID
jobID sources.JobID jobID sources.JobID
verify bool verify bool
orgsCache cache.Cache orgsCache cache.Cache[string]
memberCache map[string]struct{} memberCache map[string]struct{}
repos []string repos []string
filteredRepoCache *filteredRepoCache filteredRepoCache *filteredRepoCache
@ -123,11 +123,11 @@ func (s *Source) JobID() sources.JobID {
// filteredRepoCache is a wrapper around cache.Cache that filters out repos // filteredRepoCache is a wrapper around cache.Cache that filters out repos
// based on include and exclude globs. // based on include and exclude globs.
type filteredRepoCache struct { type filteredRepoCache struct {
cache.Cache cache.Cache[string]
include, exclude []glob.Glob include, exclude []glob.Glob
} }
func (s *Source) newFilteredRepoCache(c cache.Cache, include, exclude []string) *filteredRepoCache { func (s *Source) newFilteredRepoCache(c cache.Cache[string], include, exclude []string) *filteredRepoCache {
includeGlobs := make([]glob.Glob, 0, len(include)) includeGlobs := make([]glob.Glob, 0, len(include))
excludeGlobs := make([]glob.Glob, 0, len(exclude)) excludeGlobs := make([]glob.Glob, 0, len(exclude))
for _, ig := range include { for _, ig := range include {
@ -209,13 +209,13 @@ func (s *Source) Init(aCtx context.Context, name string, jobID sources.JobID, so
} }
s.conn = &conn s.conn = &conn
s.orgsCache = memory.New() s.orgsCache = memory.New[string]()
for _, org := range s.conn.Organizations { for _, org := range s.conn.Organizations {
s.orgsCache.Set(org, org) s.orgsCache.Set(org, org)
} }
s.memberCache = make(map[string]struct{}) s.memberCache = make(map[string]struct{})
s.filteredRepoCache = s.newFilteredRepoCache(memory.New(), s.filteredRepoCache = s.newFilteredRepoCache(memory.New[string](),
append(s.conn.GetRepositories(), s.conn.GetIncludeRepos()...), append(s.conn.GetRepositories(), s.conn.GetIncludeRepos()...),
s.conn.GetIgnoreRepos(), s.conn.GetIgnoreRepos(),
) )
@ -409,19 +409,13 @@ RepoLoop:
for _, repo := range s.filteredRepoCache.Values() { for _, repo := range s.filteredRepoCache.Values() {
repoCtx := context.WithValue(ctx, "repo", repo) repoCtx := context.WithValue(ctx, "repo", repo)
r, ok := repo.(string)
if !ok {
repoCtx.Logger().Error(fmt.Errorf("type assertion failed"), "Unexpected value in cache")
continue
}
// Ensure that |s.repoInfoCache| contains an entry for |repo|. // Ensure that |s.repoInfoCache| contains an entry for |repo|.
// This compensates for differences in enumeration logic between `--org` and `--repo`. // This compensates for differences in enumeration logic between `--org` and `--repo`.
// See: https://github.com/trufflesecurity/trufflehog/pull/2379#discussion_r1487454788 // See: https://github.com/trufflesecurity/trufflehog/pull/2379#discussion_r1487454788
if _, ok := s.repoInfoCache.get(r); !ok { if _, ok := s.repoInfoCache.get(repo); !ok {
repoCtx.Logger().V(2).Info("Caching repository info") repoCtx.Logger().V(2).Info("Caching repository info")
_, urlParts, err := getRepoURLParts(r) _, urlParts, err := getRepoURLParts(repo)
if err != nil { if err != nil {
repoCtx.Logger().Error(err, "Failed to parse repository URL") repoCtx.Logger().Error(err, "Failed to parse repository URL")
continue continue
@ -434,7 +428,7 @@ RepoLoop:
gist, _, err := s.apiClient.Gists.Get(repoCtx, gistID) gist, _, err := s.apiClient.Gists.Get(repoCtx, gistID)
// Normalize the URL to the Gist's pull URL. // Normalize the URL to the Gist's pull URL.
// See https://github.com/trufflesecurity/trufflehog/pull/2625#issuecomment-2025507937 // See https://github.com/trufflesecurity/trufflehog/pull/2625#issuecomment-2025507937
r = gist.GetGitPullURL() repo = gist.GetGitPullURL()
if s.handleRateLimit(err) { if s.handleRateLimit(err) {
continue continue
} }
@ -461,7 +455,7 @@ RepoLoop:
} }
} }
} }
s.repos = append(s.repos, r) s.repos = append(s.repos, repo)
} }
githubReposEnumerated.WithLabelValues(s.name).Set(float64(len(s.repos))) githubReposEnumerated.WithLabelValues(s.name).Set(float64(len(s.repos)))
s.log.Info("Completed enumeration", "num_repos", len(s.repos), "num_orgs", s.orgsCache.Count(), "num_members", len(s.memberCache)) s.log.Info("Completed enumeration", "num_repos", len(s.repos), "num_orgs", s.orgsCache.Count(), "num_members", len(s.memberCache))

View file

@ -380,7 +380,7 @@ func TestEnumerateUnauthenticated(t *testing.T) {
JSON([]map[string]string{{"full_name": "super-secret-org/super-secret-repo", "clone_url": "https://github.com/super-secret-org/super-secret-repo.git"}}) JSON([]map[string]string{{"full_name": "super-secret-org/super-secret-repo", "clone_url": "https://github.com/super-secret-org/super-secret-repo.git"}})
s := initTestSource(nil) s := initTestSource(nil)
s.orgsCache = memory.New() s.orgsCache = memory.New[string]()
s.orgsCache.Set("super-secret-org", "super-secret-org") s.orgsCache.Set("super-secret-org", "super-secret-org")
s.enumerateUnauthenticated(context.Background(), apiEndpoint) s.enumerateUnauthenticated(context.Background(), apiEndpoint)
assert.Equal(t, 1, s.filteredRepoCache.Count()) assert.Equal(t, 1, s.filteredRepoCache.Count())