refactor(cache): use generics (#2930)

This commit is contained in:
Richard Gomez 2024-06-06 13:08:00 -04:00 committed by GitHub
parent ea9f8ace9f
commit 5216142960
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 66 additions and 63 deletions

8
pkg/cache/cache.go vendored
View file

@ -2,11 +2,11 @@
package cache
// Cache is used to store key/value pairs.
type Cache interface {
type Cache[T any] interface {
// Set stores the given key/value pair.
Set(string, any)
Set(string, T)
// Get returns the value for the given key and a boolean indicating if the key was found.
Get(string) (any, bool)
Get(string) (T, bool)
// Exists returns true if the given key exists in the cache.
Exists(string) bool
// Delete the given key from the cache.
@ -18,7 +18,7 @@ type Cache interface {
// Keys returns all keys in the cache.
Keys() []string
// Values returns all values in the cache.
Values() []any
Values() []T
// Contents returns all keys in the cache encoded as a string.
Contents() string
}

View file

@ -14,46 +14,46 @@ const (
)
// Cache wraps the go-cache library to provide an in-memory key-value store.
// The type parameter T is the value type accepted by Set and returned by Get and Values.
type Cache struct {
type Cache[T any] struct {
// c is the underlying go-cache instance; every method delegates to it.
c *cache.Cache
// expiration is the item expiration interval; it can be overridden with WithExpirationInterval.
expiration time.Duration
// purgeInterval is the interval between purges of expired items; it can be overridden with WithPurgeInterval.
purgeInterval time.Duration
}
// CacheOption defines a function type used for configuring a Cache.
// Each option receives the Cache under construction and mutates it in place;
// NewWithData applies the options in the order they are passed.
type CacheOption func(*Cache)
type CacheOption[T any] func(*Cache[T])
// WithExpirationInterval returns a CacheOption to set the expiration interval of cache items.
// The interval determines the duration a cached item remains in the cache before it is expired.
// T must match the value type of the Cache being configured; because T does not appear in the
// parameter list it cannot be inferred from interval and has to be given explicitly.
func WithExpirationInterval(interval time.Duration) CacheOption {
return func(c *Cache) { c.expiration = interval }
func WithExpirationInterval[T any](interval time.Duration) CacheOption[T] {
return func(c *Cache[T]) { c.expiration = interval }
}
// WithPurgeInterval returns a CacheOption to set the interval at which the cache purges expired items.
// Regular purging helps in freeing up memory by removing stale entries.
// T must match the value type of the Cache being configured; because T does not appear in the
// parameter list it cannot be inferred from interval and has to be given explicitly.
func WithPurgeInterval(interval time.Duration) CacheOption {
return func(c *Cache) { c.purgeInterval = interval }
func WithPurgeInterval[T any](interval time.Duration) CacheOption[T] {
return func(c *Cache[T]) { c.purgeInterval = interval }
}
// New constructs a new in-memory cache instance with optional configurations.
// By default, it sets the expiration and purge intervals to 12 and 13 hours, respectively.
// These defaults can be overridden using the functional options: WithExpirationInterval and WithPurgeInterval.
// It is shorthand for NewWithData with no initial entries.
func New(opts ...CacheOption) *Cache {
return NewWithData(nil, opts...)
func New[T any](opts ...CacheOption[T]) *Cache[T] {
// T is passed explicitly because a nil data argument gives the compiler nothing to infer it from
// when no options are supplied.
return NewWithData[T](nil, opts...)
}
// CacheEntry represents a single entry in the cache, consisting of a key and its corresponding value.
type CacheEntry struct {
type CacheEntry[T any] struct {
// Key is the unique identifier for the entry.
Key string
// Value is the data stored in the entry.
Value any
Value T
}
// NewWithData constructs a new in-memory cache with existing data.
// It also accepts CacheOption parameters to override default configuration values.
func NewWithData(data []CacheEntry, opts ...CacheOption) *Cache {
instance := &Cache{expiration: defaultExpirationInterval, purgeInterval: defaultPurgeInterval}
func NewWithData[T any](data []CacheEntry[T], opts ...CacheOption[T]) *Cache[T] {
instance := &Cache[T]{expiration: defaultExpirationInterval, purgeInterval: defaultPurgeInterval}
for _, opt := range opts {
opt(instance)
}
@ -69,38 +69,46 @@ func NewWithData(data []CacheEntry, opts ...CacheOption) *Cache {
}
// Set adds a key-value pair to the cache.
// The entry is handed to the underlying go-cache store with defaultExpiration
// (declared elsewhere in this package).
func (c *Cache) Set(key string, value any) {
func (c *Cache[T]) Set(key string, value T) {
c.c.Set(key, value, defaultExpiration)
}
// Get returns the value for the given key and a boolean reporting whether a
// value of type T was found. When the key is absent, or the stored value is
// not of type T, it returns the zero value of T and false.
func (c *Cache) Get(key string) (any, bool) {
return c.c.Get(key)
func (c *Cache[T]) Get(key string) (T, bool) {
var value T
v, ok := c.c.Get(key)
if !ok {
return value, false
}
// The underlying store hands back an untyped value, so it is narrowed to T here.
// A failed assertion leaves value as the zero T and reports false, even though
// Exists would still report true for the same key.
value, ok = v.(T)
return value, ok
}
// Exists returns true if the given key exists in the cache.
// Only the presence of the key in the underlying store is checked; the stored
// value is discarded without being compared against T.
func (c *Cache) Exists(key string) bool {
func (c *Cache[T]) Exists(key string) bool {
_, ok := c.c.Get(key)
return ok
}
// Delete removes the key-value pair from the cache.
func (c *Cache) Delete(key string) {
func (c *Cache[T]) Delete(key string) {
c.c.Delete(key)
}
// Clear removes all key-value pairs from the cache.
func (c *Cache) Clear() {
func (c *Cache[T]) Clear() {
c.c.Flush()
}
// Count returns the number of key-value pairs in the cache.
func (c *Cache) Count() int {
func (c *Cache[T]) Count() int {
return c.c.ItemCount()
}
// Keys returns all keys in the cache.
func (c *Cache) Keys() []string {
func (c *Cache[T]) Keys() []string {
items := c.c.Items()
res := make([]string, 0, len(items))
for k := range items {
@ -110,17 +118,20 @@ func (c *Cache) Keys() []string {
}
// Values returns all values in the cache.
// Items whose stored object is not of type T are omitted from the result, so
// the returned slice may be shorter than the number of items in the store.
func (c *Cache) Values() []any {
func (c *Cache[T]) Values() []T {
items := c.c.Items()
res := make([]any, 0, len(items))
res := make([]T, 0, len(items))
for _, v := range items {
res = append(res, v.Object)
// Skip, rather than fail on, any item that does not hold a T.
obj, ok := v.Object.(T)
if ok {
res = append(res, obj)
}
}
return res
}
// Contents returns a comma-separated string containing all keys in the cache.
func (c *Cache) Contents() string {
func (c *Cache[T]) Contents() string {
items := c.c.Items()
res := make([]string, 0, len(items))
for k := range items {

View file

@ -10,7 +10,7 @@ import (
)
func TestCache(t *testing.T) {
c := New()
c := New[string]()
// Test set and get.
c.Set("key1", "key1")
@ -32,7 +32,7 @@ func TestCache(t *testing.T) {
// Test delete.
c.Delete("key1")
v, ok = c.Get("key1")
if ok || v != nil {
if ok || v != "" {
t.Fatalf("Unexpected value for key1 after delete: %v, %v", v, ok)
}
@ -40,7 +40,7 @@ func TestCache(t *testing.T) {
c.Set("key10", "key10")
c.Clear()
v, ok = c.Get("key10")
if ok || v != nil {
if ok || v != "" {
t.Fatalf("Unexpected value for key10 after clear: %v, %v", v, ok)
}
@ -59,9 +59,7 @@ func TestCache(t *testing.T) {
// Test getting only the values.
vals := make([]string, 0, c.Count())
for _, v := range c.Values() {
vals = append(vals, v.(string))
}
vals = append(vals, c.Values()...)
sort.Strings(vals)
sort.Strings(values)
if !cmp.Equal(values, vals) {
@ -83,7 +81,7 @@ func TestCache(t *testing.T) {
}
func TestCache_NewWithData(t *testing.T) {
data := []CacheEntry{{"key1", "value1"}, {"key2", "value2"}, {"key3", "value3"}}
data := []CacheEntry[string]{{"key1", "value1"}, {"key2", "value2"}, {"key3", "value3"}}
c := NewWithData(data)
// Test the count.
@ -106,10 +104,10 @@ func TestCache_NewWithData(t *testing.T) {
}
}
func setupBenchmarks(b *testing.B) *Cache {
func setupBenchmarks(b *testing.B) *Cache[string] {
b.Helper()
c := New()
c := New[string]()
for i := 0; i < 500_000; i++ {
key := fmt.Sprintf("key%d", i)
@ -120,7 +118,7 @@ func setupBenchmarks(b *testing.B) *Cache {
}
func BenchmarkSet(b *testing.B) {
c := New()
c := New[string]()
for i := 0; i < b.N; i++ {
key := fmt.Sprintf("key%d", i)

View file

@ -81,11 +81,11 @@ type Source struct {
// at given increments.
type persistableCache struct {
// persistIncrement presumably controls how often the contents are persisted;
// the check lives in shouldPersist, which is not visible here — TODO confirm.
persistIncrement int
// Embedded cache; its Set method is overridden by persistableCache.Set.
cache.Cache
cache.Cache[string]
// Embedded progress; persistableCache.Set writes the cache contents to its EncodedResumeInfo.
*sources.Progress
}
func newPersistableCache(increment int, cache cache.Cache, p *sources.Progress) *persistableCache {
func newPersistableCache(increment int, cache cache.Cache[string], p *sources.Progress) *persistableCache {
return &persistableCache{
persistIncrement: increment,
Cache: cache,
@ -95,7 +95,7 @@ func newPersistableCache(increment int, cache cache.Cache, p *sources.Progress)
// Set overrides the embedded cache's Set method to persist the cache contents
// to the Progress of the source at given increments.
func (c *persistableCache) Set(key string, val any) {
func (c *persistableCache) Set(key string, val string) {
c.Cache.Set(key, val)
if ok, contents := c.shouldPersist(); ok {
c.Progress.EncodedResumeInfo = contents
@ -293,18 +293,18 @@ func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk, _ .
}
func (s *Source) setupCache(ctx context.Context) *persistableCache {
var c cache.Cache
var c cache.Cache[string]
if s.Progress.EncodedResumeInfo != "" {
keys := strings.Split(s.Progress.EncodedResumeInfo, ",")
entries := make([]memory.CacheEntry, len(keys))
entries := make([]memory.CacheEntry[string], len(keys))
for i, val := range keys {
entries[i] = memory.CacheEntry{Key: val, Value: val}
entries[i] = memory.CacheEntry[string]{Key: val, Value: val}
}
c = memory.NewWithData(entries)
c = memory.NewWithData[string](entries)
ctx.Logger().V(3).Info("Loaded cache", "num_entries", len(entries))
} else {
c = memory.New()
c = memory.New[string]()
}
// TODO (ahrav): Make this configurable via conn.
@ -312,7 +312,7 @@ func (s *Source) setupCache(ctx context.Context) *persistableCache {
return persistCache
}
func (s *Source) setProgress(ctx context.Context, md5, objName string, cache cache.Cache) {
func (s *Source) setProgress(ctx context.Context, md5, objName string, cache cache.Cache[string]) {
s.mu.Lock()
defer s.mu.Unlock()

View file

@ -57,7 +57,7 @@ type Source struct {
sourceID sources.SourceID
jobID sources.JobID
verify bool
orgsCache cache.Cache
orgsCache cache.Cache[string]
memberCache map[string]struct{}
repos []string
filteredRepoCache *filteredRepoCache
@ -123,11 +123,11 @@ func (s *Source) JobID() sources.JobID {
// filteredRepoCache is a wrapper around cache.Cache that filters out repos
// based on include and exclude globs.
type filteredRepoCache struct {
// Embedded cache holding the repos as string values.
cache.Cache
cache.Cache[string]
// include and exclude are the glob matchers; newFilteredRepoCache appears to build
// them from its include/exclude string arguments (body only partly visible here).
include, exclude []glob.Glob
}
func (s *Source) newFilteredRepoCache(c cache.Cache, include, exclude []string) *filteredRepoCache {
func (s *Source) newFilteredRepoCache(c cache.Cache[string], include, exclude []string) *filteredRepoCache {
includeGlobs := make([]glob.Glob, 0, len(include))
excludeGlobs := make([]glob.Glob, 0, len(exclude))
for _, ig := range include {
@ -209,13 +209,13 @@ func (s *Source) Init(aCtx context.Context, name string, jobID sources.JobID, so
}
s.conn = &conn
s.orgsCache = memory.New()
s.orgsCache = memory.New[string]()
for _, org := range s.conn.Organizations {
s.orgsCache.Set(org, org)
}
s.memberCache = make(map[string]struct{})
s.filteredRepoCache = s.newFilteredRepoCache(memory.New(),
s.filteredRepoCache = s.newFilteredRepoCache(memory.New[string](),
append(s.conn.GetRepositories(), s.conn.GetIncludeRepos()...),
s.conn.GetIgnoreRepos(),
)
@ -409,19 +409,13 @@ RepoLoop:
for _, repo := range s.filteredRepoCache.Values() {
repoCtx := context.WithValue(ctx, "repo", repo)
r, ok := repo.(string)
if !ok {
repoCtx.Logger().Error(fmt.Errorf("type assertion failed"), "Unexpected value in cache")
continue
}
// Ensure that |s.repoInfoCache| contains an entry for |repo|.
// This compensates for differences in enumeration logic between `--org` and `--repo`.
// See: https://github.com/trufflesecurity/trufflehog/pull/2379#discussion_r1487454788
if _, ok := s.repoInfoCache.get(r); !ok {
if _, ok := s.repoInfoCache.get(repo); !ok {
repoCtx.Logger().V(2).Info("Caching repository info")
_, urlParts, err := getRepoURLParts(r)
_, urlParts, err := getRepoURLParts(repo)
if err != nil {
repoCtx.Logger().Error(err, "Failed to parse repository URL")
continue
@ -434,7 +428,7 @@ RepoLoop:
gist, _, err := s.apiClient.Gists.Get(repoCtx, gistID)
// Normalize the URL to the Gist's pull URL.
// See https://github.com/trufflesecurity/trufflehog/pull/2625#issuecomment-2025507937
r = gist.GetGitPullURL()
repo = gist.GetGitPullURL()
if s.handleRateLimit(err) {
continue
}
@ -461,7 +455,7 @@ RepoLoop:
}
}
}
s.repos = append(s.repos, r)
s.repos = append(s.repos, repo)
}
githubReposEnumerated.WithLabelValues(s.name).Set(float64(len(s.repos)))
s.log.Info("Completed enumeration", "num_repos", len(s.repos), "num_orgs", s.orgsCache.Count(), "num_members", len(s.memberCache))

View file

@ -380,7 +380,7 @@ func TestEnumerateUnauthenticated(t *testing.T) {
JSON([]map[string]string{{"full_name": "super-secret-org/super-secret-repo", "clone_url": "https://github.com/super-secret-org/super-secret-repo.git"}})
s := initTestSource(nil)
s.orgsCache = memory.New()
s.orgsCache = memory.New[string]()
s.orgsCache.Set("super-secret-org", "super-secret-org")
s.enumerateUnauthenticated(context.Background(), apiEndpoint)
assert.Equal(t, 1, s.filteredRepoCache.Count())