mirror of
https://github.com/trufflesecurity/trufflehog.git
synced 2024-11-10 07:04:24 +00:00
refactor(cache): use generics (#2930)
This commit is contained in:
parent
ea9f8ace9f
commit
5216142960
6 changed files with 66 additions and 63 deletions
8
pkg/cache/cache.go
vendored
8
pkg/cache/cache.go
vendored
|
@ -2,11 +2,11 @@
|
|||
package cache
|
||||
|
||||
// Cache is used to store key/value pairs.
|
||||
type Cache interface {
|
||||
type Cache[T any] interface {
|
||||
// Set stores the given key/value pair.
|
||||
Set(string, any)
|
||||
Set(string, T)
|
||||
// Get returns the value for the given key and a boolean indicating if the key was found.
|
||||
Get(string) (any, bool)
|
||||
Get(string) (T, bool)
|
||||
// Exists returns true if the given key exists in the cache.
|
||||
Exists(string) bool
|
||||
// Delete the given key from the cache.
|
||||
|
@ -18,7 +18,7 @@ type Cache interface {
|
|||
// Keys returns all keys in the cache.
|
||||
Keys() []string
|
||||
// Values returns all values in the cache.
|
||||
Values() []any
|
||||
Values() []T
|
||||
// Contents returns all keys in the cache encoded as a string.
|
||||
Contents() string
|
||||
}
|
||||
|
|
59
pkg/cache/memory/memory.go
vendored
59
pkg/cache/memory/memory.go
vendored
|
@ -14,46 +14,46 @@ const (
|
|||
)
|
||||
|
||||
// Cache wraps the go-cache library to provide an in-memory key-value store.
|
||||
type Cache struct {
|
||||
type Cache[T any] struct {
|
||||
c *cache.Cache
|
||||
expiration time.Duration
|
||||
purgeInterval time.Duration
|
||||
}
|
||||
|
||||
// CacheOption defines a function type used for configuring a Cache.
|
||||
type CacheOption func(*Cache)
|
||||
type CacheOption[T any] func(*Cache[T])
|
||||
|
||||
// WithExpirationInterval returns a CacheOption to set the expiration interval of cache items.
|
||||
// The interval determines the duration a cached item remains in the cache before it is expired.
|
||||
func WithExpirationInterval(interval time.Duration) CacheOption {
|
||||
return func(c *Cache) { c.expiration = interval }
|
||||
func WithExpirationInterval[T any](interval time.Duration) CacheOption[T] {
|
||||
return func(c *Cache[T]) { c.expiration = interval }
|
||||
}
|
||||
|
||||
// WithPurgeInterval returns a CacheOption to set the interval at which the cache purges expired items.
|
||||
// Regular purging helps in freeing up memory by removing stale entries.
|
||||
func WithPurgeInterval(interval time.Duration) CacheOption {
|
||||
return func(c *Cache) { c.purgeInterval = interval }
|
||||
func WithPurgeInterval[T any](interval time.Duration) CacheOption[T] {
|
||||
return func(c *Cache[T]) { c.purgeInterval = interval }
|
||||
}
|
||||
|
||||
// New constructs a new in-memory cache instance with optional configurations.
|
||||
// By default, it sets the expiration and purge intervals to 12 and 13 hours, respectively.
|
||||
// These defaults can be overridden using the functional options: WithExpirationInterval and WithPurgeInterval.
|
||||
func New(opts ...CacheOption) *Cache {
|
||||
return NewWithData(nil, opts...)
|
||||
func New[T any](opts ...CacheOption[T]) *Cache[T] {
|
||||
return NewWithData[T](nil, opts...)
|
||||
}
|
||||
|
||||
// CacheEntry represents a single entry in the cache, consisting of a key and its corresponding value.
|
||||
type CacheEntry struct {
|
||||
type CacheEntry[T any] struct {
|
||||
// Key is the unique identifier for the entry.
|
||||
Key string
|
||||
// Value is the data stored in the entry.
|
||||
Value any
|
||||
Value T
|
||||
}
|
||||
|
||||
// NewWithData constructs a new in-memory cache with existing data.
|
||||
// It also accepts CacheOption parameters to override default configuration values.
|
||||
func NewWithData(data []CacheEntry, opts ...CacheOption) *Cache {
|
||||
instance := &Cache{expiration: defaultExpirationInterval, purgeInterval: defaultPurgeInterval}
|
||||
func NewWithData[T any](data []CacheEntry[T], opts ...CacheOption[T]) *Cache[T] {
|
||||
instance := &Cache[T]{expiration: defaultExpirationInterval, purgeInterval: defaultPurgeInterval}
|
||||
for _, opt := range opts {
|
||||
opt(instance)
|
||||
}
|
||||
|
@ -69,38 +69,46 @@ func NewWithData(data []CacheEntry, opts ...CacheOption) *Cache {
|
|||
}
|
||||
|
||||
// Set adds a key-value pair to the cache.
|
||||
func (c *Cache) Set(key string, value any) {
|
||||
func (c *Cache[T]) Set(key string, value T) {
|
||||
c.c.Set(key, value, defaultExpiration)
|
||||
}
|
||||
|
||||
// Get returns the value for the given key.
|
||||
func (c *Cache) Get(key string) (any, bool) {
|
||||
return c.c.Get(key)
|
||||
func (c *Cache[T]) Get(key string) (T, bool) {
|
||||
var value T
|
||||
|
||||
v, ok := c.c.Get(key)
|
||||
if !ok {
|
||||
return value, false
|
||||
}
|
||||
|
||||
value, ok = v.(T)
|
||||
return value, ok
|
||||
}
|
||||
|
||||
// Exists returns true if the given key exists in the cache.
|
||||
func (c *Cache) Exists(key string) bool {
|
||||
func (c *Cache[T]) Exists(key string) bool {
|
||||
_, ok := c.c.Get(key)
|
||||
return ok
|
||||
}
|
||||
|
||||
// Delete removes the key-value pair from the cache.
|
||||
func (c *Cache) Delete(key string) {
|
||||
func (c *Cache[T]) Delete(key string) {
|
||||
c.c.Delete(key)
|
||||
}
|
||||
|
||||
// Clear removes all key-value pairs from the cache.
|
||||
func (c *Cache) Clear() {
|
||||
func (c *Cache[T]) Clear() {
|
||||
c.c.Flush()
|
||||
}
|
||||
|
||||
// Count returns the number of key-value pairs in the cache.
|
||||
func (c *Cache) Count() int {
|
||||
func (c *Cache[T]) Count() int {
|
||||
return c.c.ItemCount()
|
||||
}
|
||||
|
||||
// Keys returns all keys in the cache.
|
||||
func (c *Cache) Keys() []string {
|
||||
func (c *Cache[T]) Keys() []string {
|
||||
items := c.c.Items()
|
||||
res := make([]string, 0, len(items))
|
||||
for k := range items {
|
||||
|
@ -110,17 +118,20 @@ func (c *Cache) Keys() []string {
|
|||
}
|
||||
|
||||
// Values returns all values in the cache.
|
||||
func (c *Cache) Values() []any {
|
||||
func (c *Cache[T]) Values() []T {
|
||||
items := c.c.Items()
|
||||
res := make([]any, 0, len(items))
|
||||
res := make([]T, 0, len(items))
|
||||
for _, v := range items {
|
||||
res = append(res, v.Object)
|
||||
obj, ok := v.Object.(T)
|
||||
if ok {
|
||||
res = append(res, obj)
|
||||
}
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
// Contents returns a comma-separated string containing all keys in the cache.
|
||||
func (c *Cache) Contents() string {
|
||||
func (c *Cache[T]) Contents() string {
|
||||
items := c.c.Items()
|
||||
res := make([]string, 0, len(items))
|
||||
for k := range items {
|
||||
|
|
18
pkg/cache/memory/memory_test.go
vendored
18
pkg/cache/memory/memory_test.go
vendored
|
@ -10,7 +10,7 @@ import (
|
|||
)
|
||||
|
||||
func TestCache(t *testing.T) {
|
||||
c := New()
|
||||
c := New[string]()
|
||||
|
||||
// Test set and get.
|
||||
c.Set("key1", "key1")
|
||||
|
@ -32,7 +32,7 @@ func TestCache(t *testing.T) {
|
|||
// Test delete.
|
||||
c.Delete("key1")
|
||||
v, ok = c.Get("key1")
|
||||
if ok || v != nil {
|
||||
if ok || v != "" {
|
||||
t.Fatalf("Unexpected value for key1 after delete: %v, %v", v, ok)
|
||||
}
|
||||
|
||||
|
@ -40,7 +40,7 @@ func TestCache(t *testing.T) {
|
|||
c.Set("key10", "key10")
|
||||
c.Clear()
|
||||
v, ok = c.Get("key10")
|
||||
if ok || v != nil {
|
||||
if ok || v != "" {
|
||||
t.Fatalf("Unexpected value for key10 after clear: %v, %v", v, ok)
|
||||
}
|
||||
|
||||
|
@ -59,9 +59,7 @@ func TestCache(t *testing.T) {
|
|||
|
||||
// Test getting only the values.
|
||||
vals := make([]string, 0, c.Count())
|
||||
for _, v := range c.Values() {
|
||||
vals = append(vals, v.(string))
|
||||
}
|
||||
vals = append(vals, c.Values()...)
|
||||
sort.Strings(vals)
|
||||
sort.Strings(values)
|
||||
if !cmp.Equal(values, vals) {
|
||||
|
@ -83,7 +81,7 @@ func TestCache(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestCache_NewWithData(t *testing.T) {
|
||||
data := []CacheEntry{{"key1", "value1"}, {"key2", "value2"}, {"key3", "value3"}}
|
||||
data := []CacheEntry[string]{{"key1", "value1"}, {"key2", "value2"}, {"key3", "value3"}}
|
||||
c := NewWithData(data)
|
||||
|
||||
// Test the count.
|
||||
|
@ -106,10 +104,10 @@ func TestCache_NewWithData(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func setupBenchmarks(b *testing.B) *Cache {
|
||||
func setupBenchmarks(b *testing.B) *Cache[string] {
|
||||
b.Helper()
|
||||
|
||||
c := New()
|
||||
c := New[string]()
|
||||
|
||||
for i := 0; i < 500_000; i++ {
|
||||
key := fmt.Sprintf("key%d", i)
|
||||
|
@ -120,7 +118,7 @@ func setupBenchmarks(b *testing.B) *Cache {
|
|||
}
|
||||
|
||||
func BenchmarkSet(b *testing.B) {
|
||||
c := New()
|
||||
c := New[string]()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
key := fmt.Sprintf("key%d", i)
|
||||
|
|
|
@ -81,11 +81,11 @@ type Source struct {
|
|||
// at given increments.
|
||||
type persistableCache struct {
|
||||
persistIncrement int
|
||||
cache.Cache
|
||||
cache.Cache[string]
|
||||
*sources.Progress
|
||||
}
|
||||
|
||||
func newPersistableCache(increment int, cache cache.Cache, p *sources.Progress) *persistableCache {
|
||||
func newPersistableCache(increment int, cache cache.Cache[string], p *sources.Progress) *persistableCache {
|
||||
return &persistableCache{
|
||||
persistIncrement: increment,
|
||||
Cache: cache,
|
||||
|
@ -95,7 +95,7 @@ func newPersistableCache(increment int, cache cache.Cache, p *sources.Progress)
|
|||
|
||||
// Set overrides the embedded cache's Set method to enable the persistence
|
||||
// of the cache contents to the Progress of the source at given increments.
|
||||
func (c *persistableCache) Set(key string, val any) {
|
||||
func (c *persistableCache) Set(key string, val string) {
|
||||
c.Cache.Set(key, val)
|
||||
if ok, contents := c.shouldPersist(); ok {
|
||||
c.Progress.EncodedResumeInfo = contents
|
||||
|
@ -293,18 +293,18 @@ func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk, _ .
|
|||
}
|
||||
|
||||
func (s *Source) setupCache(ctx context.Context) *persistableCache {
|
||||
var c cache.Cache
|
||||
var c cache.Cache[string]
|
||||
if s.Progress.EncodedResumeInfo != "" {
|
||||
keys := strings.Split(s.Progress.EncodedResumeInfo, ",")
|
||||
entries := make([]memory.CacheEntry, len(keys))
|
||||
entries := make([]memory.CacheEntry[string], len(keys))
|
||||
for i, val := range keys {
|
||||
entries[i] = memory.CacheEntry{Key: val, Value: val}
|
||||
entries[i] = memory.CacheEntry[string]{Key: val, Value: val}
|
||||
}
|
||||
|
||||
c = memory.NewWithData(entries)
|
||||
c = memory.NewWithData[string](entries)
|
||||
ctx.Logger().V(3).Info("Loaded cache", "num_entries", len(entries))
|
||||
} else {
|
||||
c = memory.New()
|
||||
c = memory.New[string]()
|
||||
}
|
||||
|
||||
// TODO (ahrav): Make this configurable via conn.
|
||||
|
@ -312,7 +312,7 @@ func (s *Source) setupCache(ctx context.Context) *persistableCache {
|
|||
return persistCache
|
||||
}
|
||||
|
||||
func (s *Source) setProgress(ctx context.Context, md5, objName string, cache cache.Cache) {
|
||||
func (s *Source) setProgress(ctx context.Context, md5, objName string, cache cache.Cache[string]) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
|
|
|
@ -57,7 +57,7 @@ type Source struct {
|
|||
sourceID sources.SourceID
|
||||
jobID sources.JobID
|
||||
verify bool
|
||||
orgsCache cache.Cache
|
||||
orgsCache cache.Cache[string]
|
||||
memberCache map[string]struct{}
|
||||
repos []string
|
||||
filteredRepoCache *filteredRepoCache
|
||||
|
@ -123,11 +123,11 @@ func (s *Source) JobID() sources.JobID {
|
|||
// filteredRepoCache is a wrapper around cache.Cache that filters out repos
|
||||
// based on include and exclude globs.
|
||||
type filteredRepoCache struct {
|
||||
cache.Cache
|
||||
cache.Cache[string]
|
||||
include, exclude []glob.Glob
|
||||
}
|
||||
|
||||
func (s *Source) newFilteredRepoCache(c cache.Cache, include, exclude []string) *filteredRepoCache {
|
||||
func (s *Source) newFilteredRepoCache(c cache.Cache[string], include, exclude []string) *filteredRepoCache {
|
||||
includeGlobs := make([]glob.Glob, 0, len(include))
|
||||
excludeGlobs := make([]glob.Glob, 0, len(exclude))
|
||||
for _, ig := range include {
|
||||
|
@ -209,13 +209,13 @@ func (s *Source) Init(aCtx context.Context, name string, jobID sources.JobID, so
|
|||
}
|
||||
s.conn = &conn
|
||||
|
||||
s.orgsCache = memory.New()
|
||||
s.orgsCache = memory.New[string]()
|
||||
for _, org := range s.conn.Organizations {
|
||||
s.orgsCache.Set(org, org)
|
||||
}
|
||||
s.memberCache = make(map[string]struct{})
|
||||
|
||||
s.filteredRepoCache = s.newFilteredRepoCache(memory.New(),
|
||||
s.filteredRepoCache = s.newFilteredRepoCache(memory.New[string](),
|
||||
append(s.conn.GetRepositories(), s.conn.GetIncludeRepos()...),
|
||||
s.conn.GetIgnoreRepos(),
|
||||
)
|
||||
|
@ -409,19 +409,13 @@ RepoLoop:
|
|||
for _, repo := range s.filteredRepoCache.Values() {
|
||||
repoCtx := context.WithValue(ctx, "repo", repo)
|
||||
|
||||
r, ok := repo.(string)
|
||||
if !ok {
|
||||
repoCtx.Logger().Error(fmt.Errorf("type assertion failed"), "Unexpected value in cache")
|
||||
continue
|
||||
}
|
||||
|
||||
// Ensure that |s.repoInfoCache| contains an entry for |repo|.
|
||||
// This compensates for differences in enumeration logic between `--org` and `--repo`.
|
||||
// See: https://github.com/trufflesecurity/trufflehog/pull/2379#discussion_r1487454788
|
||||
if _, ok := s.repoInfoCache.get(r); !ok {
|
||||
if _, ok := s.repoInfoCache.get(repo); !ok {
|
||||
repoCtx.Logger().V(2).Info("Caching repository info")
|
||||
|
||||
_, urlParts, err := getRepoURLParts(r)
|
||||
_, urlParts, err := getRepoURLParts(repo)
|
||||
if err != nil {
|
||||
repoCtx.Logger().Error(err, "Failed to parse repository URL")
|
||||
continue
|
||||
|
@ -434,7 +428,7 @@ RepoLoop:
|
|||
gist, _, err := s.apiClient.Gists.Get(repoCtx, gistID)
|
||||
// Normalize the URL to the Gist's pull URL.
|
||||
// See https://github.com/trufflesecurity/trufflehog/pull/2625#issuecomment-2025507937
|
||||
r = gist.GetGitPullURL()
|
||||
repo = gist.GetGitPullURL()
|
||||
if s.handleRateLimit(err) {
|
||||
continue
|
||||
}
|
||||
|
@ -461,7 +455,7 @@ RepoLoop:
|
|||
}
|
||||
}
|
||||
}
|
||||
s.repos = append(s.repos, r)
|
||||
s.repos = append(s.repos, repo)
|
||||
}
|
||||
githubReposEnumerated.WithLabelValues(s.name).Set(float64(len(s.repos)))
|
||||
s.log.Info("Completed enumeration", "num_repos", len(s.repos), "num_orgs", s.orgsCache.Count(), "num_members", len(s.memberCache))
|
||||
|
|
|
@ -380,7 +380,7 @@ func TestEnumerateUnauthenticated(t *testing.T) {
|
|||
JSON([]map[string]string{{"full_name": "super-secret-org/super-secret-repo", "clone_url": "https://github.com/super-secret-org/super-secret-repo.git"}})
|
||||
|
||||
s := initTestSource(nil)
|
||||
s.orgsCache = memory.New()
|
||||
s.orgsCache = memory.New[string]()
|
||||
s.orgsCache.Set("super-secret-org", "super-secret-org")
|
||||
s.enumerateUnauthenticated(context.Background(), apiEndpoint)
|
||||
assert.Equal(t, 1, s.filteredRepoCache.Count())
|
||||
|
|
Loading…
Reference in a new issue