New Source: HuggingFace (#3000)

* initial spike on hf

* added in user and org enum

* adding huggingface source

* updated with lint suggestions

* updated readme

* addressing resources that require org approval to access

* removing unneeded code

* updating with new error msg for 403

* deleted unused code + added resource check in main
This commit is contained in:
joeleonjr 2024-06-27 13:22:06 -04:00 committed by GitHub
parent e9206c66bb
commit 01a1499600
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 4568 additions and 964 deletions

View file

@ -301,6 +301,27 @@ trufflehog elasticsearch \
--api-key 'MlVtVjBZ...ZSYlduYnF1djh3NG5FQQ=='
```
## 15. Scan HuggingFace
### Scan a HuggingFace Model, Dataset or Space
```bash
trufflehog huggingface --model <username/modelname> --space <username/spacename> --dataset <username/datasetname>
```
### Scan all Models, Datasets and Space belonging to a HuggingFace Org/User
```bash
trufflehog huggingface --org <orgname> --user <username>
```
Optionally, skip scanning a type of resource with `--skip-models`, `--skip-datasets`, `--skip-spaces` or a particular resource with `--ignore-models/datasets/spaces <resource-name>`.
### Scan Discussion and PR Comments
```bash
trufflehog huggingface --model <username/modelname> --include-discussions --include-prs
```
# :question: FAQ
- All I see is `🐷🔑🐷 TruffleHog. Unearth your secrets. 🐷🔑🐷` and the program exits, what gives?

54
main.go
View file

@ -198,6 +198,27 @@ var (
jenkinsPassword = jenkinsScan.Flag("password", "Jenkins password").Envar("JENKINS_PASSWORD").String()
jenkinsInsecureSkipVerifyTLS = jenkinsScan.Flag("insecure-skip-verify-tls", "Skip TLS verification").Envar("JENKINS_INSECURE_SKIP_VERIFY_TLS").Bool()
huggingfaceScan = cli.Command("huggingface", "Find credentials in HuggingFace datasets, models and spaces.")
huggingfaceEndpoint = huggingfaceScan.Flag("endpoint", "HuggingFace endpoint.").Default("https://huggingface.co").String()
huggingfaceModels = huggingfaceScan.Flag("model", "HuggingFace model to scan. You can repeat this flag. Example: 'username/model'").Strings()
huggingfaceSpaces = huggingfaceScan.Flag("space", "HuggingFace space to scan. You can repeat this flag. Example: 'username/space'").Strings()
huggingfaceDatasets = huggingfaceScan.Flag("dataset", "HuggingFace dataset to scan. You can repeat this flag. Example: 'username/dataset'").Strings()
huggingfaceOrgs = huggingfaceScan.Flag("org", `HuggingFace organization to scan. You can repeat this flag. Example: "trufflesecurity"`).Strings()
huggingfaceUsers = huggingfaceScan.Flag("user", `HuggingFace user to scan. You can repeat this flag. Example: "trufflesecurity"`).Strings()
huggingfaceToken = huggingfaceScan.Flag("token", "HuggingFace token. Can be provided with environment variable HUGGINGFACE_TOKEN.").Envar("HUGGINGFACE_TOKEN").String()
huggingfaceIncludeModels = huggingfaceScan.Flag("include-models", "Models to include in scan. You can repeat this flag. Must use HuggingFace model full name. Example: 'username/model' (Only used with --user or --org)").Strings()
huggingfaceIncludeSpaces = huggingfaceScan.Flag("include-spaces", "Spaces to include in scan. You can repeat this flag. Must use HuggingFace space full name. Example: 'username/space' (Only used with --user or --org)").Strings()
huggingfaceIncludeDatasets = huggingfaceScan.Flag("include-datasets", "Datasets to include in scan. You can repeat this flag. Must use HuggingFace dataset full name. Example: 'username/dataset' (Only used with --user or --org)").Strings()
huggingfaceIgnoreModels = huggingfaceScan.Flag("ignore-models", "Models to ignore in scan. You can repeat this flag. Must use HuggingFace model full name. Example: 'username/model' (Only used with --user or --org)").Strings()
huggingfaceIgnoreSpaces = huggingfaceScan.Flag("ignore-spaces", "Spaces to ignore in scan. You can repeat this flag. Must use HuggingFace space full name. Example: 'username/space' (Only used with --user or --org)").Strings()
huggingfaceIgnoreDatasets = huggingfaceScan.Flag("ignore-datasets", "Datasets to ignore in scan. You can repeat this flag. Must use HuggingFace dataset full name. Example: 'username/dataset' (Only used with --user or --org)").Strings()
huggingfaceSkipAllModels = huggingfaceScan.Flag("skip-all-models", "Skip all model scans. (Only used with --user or --org)").Bool()
huggingfaceSkipAllSpaces = huggingfaceScan.Flag("skip-all-spaces", "Skip all space scans. (Only used with --user or --org)").Bool()
huggingfaceSkipAllDatasets = huggingfaceScan.Flag("skip-all-datasets", "Skip all dataset scans. (Only used with --user or --org)").Bool()
huggingfaceIncludeDiscussions = huggingfaceScan.Flag("include-discussions", "Include discussions in scan.").Bool()
huggingfaceIncludePrs = huggingfaceScan.Flag("include-prs", "Include pull requests in scan.").Bool()
usingTUI = false
)
@ -738,6 +759,39 @@ func runSingleScan(ctx context.Context, cmd string, cfg engine.Config) (metrics,
if err := eng.ScanJenkins(ctx, cfg); err != nil {
return scanMetrics, fmt.Errorf("failed to scan Jenkins: %v", err)
}
case huggingfaceScan.FullCommand():
if *huggingfaceEndpoint != "" {
*huggingfaceEndpoint = strings.TrimRight(*huggingfaceEndpoint, "/")
}
if len(*huggingfaceModels) == 0 && len(*huggingfaceSpaces) == 0 && len(*huggingfaceDatasets) == 0 && len(*huggingfaceOrgs) == 0 && len(*huggingfaceUsers) == 0 {
return scanMetrics, fmt.Errorf("invalid config: you must specify at least one organization, user, model, space or dataset")
}
cfg := engine.HuggingfaceConfig{
Endpoint: *huggingfaceEndpoint,
Models: *huggingfaceModels,
Spaces: *huggingfaceSpaces,
Datasets: *huggingfaceDatasets,
Organizations: *huggingfaceOrgs,
Users: *huggingfaceUsers,
Token: *huggingfaceToken,
IncludeModels: *huggingfaceIncludeModels,
IncludeSpaces: *huggingfaceIncludeSpaces,
IncludeDatasets: *huggingfaceIncludeDatasets,
IgnoreModels: *huggingfaceIgnoreModels,
IgnoreSpaces: *huggingfaceIgnoreSpaces,
IgnoreDatasets: *huggingfaceIgnoreDatasets,
SkipAllModels: *huggingfaceSkipAllModels,
SkipAllSpaces: *huggingfaceSkipAllSpaces,
SkipAllDatasets: *huggingfaceSkipAllDatasets,
IncludeDiscussions: *huggingfaceIncludeDiscussions,
IncludePrs: *huggingfaceIncludePrs,
Concurrency: *concurrency,
}
if err := eng.ScanHuggingface(ctx, cfg); err != nil {
return scanMetrics, fmt.Errorf("failed to scan HuggingFace: %v", err)
}
default:
return scanMetrics, fmt.Errorf("invalid command: %s", cmd)
}

80
pkg/engine/huggingface.go Normal file
View file

@ -0,0 +1,80 @@
package engine
import (
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/anypb"
"github.com/trufflesecurity/trufflehog/v3/pkg/context"
"github.com/trufflesecurity/trufflehog/v3/pkg/pb/sourcespb"
"github.com/trufflesecurity/trufflehog/v3/pkg/sources/huggingface"
)
// HuggingFaceConfig represents the configuration for HuggingFace.
type HuggingfaceConfig struct {
Endpoint string
Models []string
Spaces []string
Datasets []string
Organizations []string
Users []string
IncludeModels []string
IgnoreModels []string
IncludeSpaces []string
IgnoreSpaces []string
IncludeDatasets []string
IgnoreDatasets []string
SkipAllModels bool
SkipAllSpaces bool
SkipAllDatasets bool
IncludeDiscussions bool
IncludePrs bool
Token string
Concurrency int
}
// ScanGitHub scans HuggingFace with the provided options.
func (e *Engine) ScanHuggingface(ctx context.Context, c HuggingfaceConfig) error {
connection := sourcespb.Huggingface{
Endpoint: c.Endpoint,
Models: c.Models,
Spaces: c.Spaces,
Datasets: c.Datasets,
Organizations: c.Organizations,
Users: c.Users,
IncludeModels: c.IncludeModels,
IgnoreModels: c.IgnoreModels,
IncludeSpaces: c.IncludeSpaces,
IgnoreSpaces: c.IgnoreSpaces,
IncludeDatasets: c.IncludeDatasets,
IgnoreDatasets: c.IgnoreDatasets,
SkipAllModels: c.SkipAllModels,
SkipAllSpaces: c.SkipAllSpaces,
SkipAllDatasets: c.SkipAllDatasets,
IncludeDiscussions: c.IncludeDiscussions,
IncludePrs: c.IncludePrs,
}
if len(c.Token) > 0 {
connection.Credential = &sourcespb.Huggingface_Token{
Token: c.Token,
}
} else {
connection.Credential = &sourcespb.Huggingface_Unauthenticated{}
}
var conn anypb.Any
err := anypb.MarshalFrom(&conn, &connection, proto.MarshalOptions{})
if err != nil {
ctx.Logger().Error(err, "failed to marshal huggingface connection")
return err
}
sourceName := "trufflehog - huggingface"
sourceID, jobID, _ := e.sourceManager.GetIDs(ctx, sourceName, sourcespb.SourceType_SOURCE_TYPE_HUGGINGFACE)
huggingfaceSource := &huggingface.Source{}
if err := huggingfaceSource.Init(ctx, sourceName, jobID, sourceID, true, &conn, c.Concurrency); err != nil {
return err
}
_, err = e.sourceManager.Run(ctx, sourceName, huggingfaceSource)
return err
}

File diff suppressed because it is too large Load diff

View file

@ -1493,6 +1493,125 @@ var _ interface {
ErrorName() string
} = GCSValidationError{}
// Validate checks the field values on Huggingface with the rules defined in
// the proto definition for this message. If any rules are violated, the first
// error encountered is returned, or nil if there are no violations.
func (m *Huggingface) Validate() error {
return m.validate(false)
}
// ValidateAll checks the field values on Huggingface with the rules defined in
// the proto definition for this message. If any rules are violated, the
// result is a list of violation errors wrapped in HuggingfaceMultiError, or
// nil if none found.
func (m *Huggingface) ValidateAll() error {
return m.validate(true)
}
func (m *Huggingface) validate(all bool) error {
if m == nil {
return nil
}
var errors []error
// no validation rules for Link
// no validation rules for Username
// no validation rules for Repository
// no validation rules for Commit
// no validation rules for Email
// no validation rules for File
// no validation rules for Timestamp
// no validation rules for Line
// no validation rules for Visibility
// no validation rules for ResourceType
if len(errors) > 0 {
return HuggingfaceMultiError(errors)
}
return nil
}
// HuggingfaceMultiError is an error wrapping multiple validation errors
// returned by Huggingface.ValidateAll() if the designated constraints aren't met.
type HuggingfaceMultiError []error
// Error returns a concatenation of all the error messages it wraps.
func (m HuggingfaceMultiError) Error() string {
var msgs []string
for _, err := range m {
msgs = append(msgs, err.Error())
}
return strings.Join(msgs, "; ")
}
// AllErrors returns a list of validation violation errors.
func (m HuggingfaceMultiError) AllErrors() []error { return m }
// HuggingfaceValidationError is the validation error returned by
// Huggingface.Validate if the designated constraints aren't met.
type HuggingfaceValidationError struct {
field string
reason string
cause error
key bool
}
// Field function returns field value.
func (e HuggingfaceValidationError) Field() string { return e.field }
// Reason function returns reason value.
func (e HuggingfaceValidationError) Reason() string { return e.reason }
// Cause function returns cause value.
func (e HuggingfaceValidationError) Cause() error { return e.cause }
// Key function returns key value.
func (e HuggingfaceValidationError) Key() bool { return e.key }
// ErrorName returns error name.
func (e HuggingfaceValidationError) ErrorName() string { return "HuggingfaceValidationError" }
// Error satisfies the builtin error interface
func (e HuggingfaceValidationError) Error() string {
cause := ""
if e.cause != nil {
cause = fmt.Sprintf(" | caused by: %v", e.cause)
}
key := ""
if e.key {
key = "key for "
}
return fmt.Sprintf(
"invalid %sHuggingface.%s: %s%s",
key,
e.field,
e.reason,
cause)
}
var _ error = HuggingfaceValidationError{}
var _ interface {
Field() string
Reason() string
Key() bool
Cause() error
ErrorName() string
} = HuggingfaceValidationError{}
// Validate checks the field values on Jira with the rules defined in the proto
// definition for this message. If any rules are violated, the first error
// encountered is returned, or nil if there are no violations.
@ -5074,6 +5193,47 @@ func (m *MetaData) validate(all bool) error {
}
}
case *MetaData_Huggingface:
if v == nil {
err := MetaDataValidationError{
field: "Data",
reason: "oneof value cannot be a typed-nil",
}
if !all {
return err
}
errors = append(errors, err)
}
if all {
switch v := interface{}(m.GetHuggingface()).(type) {
case interface{ ValidateAll() error }:
if err := v.ValidateAll(); err != nil {
errors = append(errors, MetaDataValidationError{
field: "Huggingface",
reason: "embedded message failed validation",
cause: err,
})
}
case interface{ Validate() error }:
if err := v.Validate(); err != nil {
errors = append(errors, MetaDataValidationError{
field: "Huggingface",
reason: "embedded message failed validation",
cause: err,
})
}
}
} else if v, ok := interface{}(m.GetHuggingface()).(interface{ Validate() error }); ok {
if err := v.Validate(); err != nil {
return MetaDataValidationError{
field: "Huggingface",
reason: "embedded message failed validation",
cause: err,
}
}
}
default:
_ = v // ensures v is used
}

File diff suppressed because it is too large Load diff

View file

@ -2875,6 +2875,185 @@ var _ interface {
ErrorName() string
} = GoogleDriveValidationError{}
// Validate checks the field values on Huggingface with the rules defined in
// the proto definition for this message. If any rules are violated, the first
// error encountered is returned, or nil if there are no violations.
func (m *Huggingface) Validate() error {
return m.validate(false)
}
// ValidateAll checks the field values on Huggingface with the rules defined in
// the proto definition for this message. If any rules are violated, the
// result is a list of violation errors wrapped in HuggingfaceMultiError, or
// nil if none found.
func (m *Huggingface) ValidateAll() error {
return m.validate(true)
}
func (m *Huggingface) validate(all bool) error {
if m == nil {
return nil
}
var errors []error
if _, err := url.Parse(m.GetEndpoint()); err != nil {
err = HuggingfaceValidationError{
field: "Endpoint",
reason: "value must be a valid URI",
cause: err,
}
if !all {
return err
}
errors = append(errors, err)
}
// no validation rules for SkipAllModels
// no validation rules for SkipAllSpaces
// no validation rules for SkipAllDatasets
// no validation rules for IncludeDiscussions
// no validation rules for IncludePrs
switch v := m.Credential.(type) {
case *Huggingface_Token:
if v == nil {
err := HuggingfaceValidationError{
field: "Credential",
reason: "oneof value cannot be a typed-nil",
}
if !all {
return err
}
errors = append(errors, err)
}
// no validation rules for Token
case *Huggingface_Unauthenticated:
if v == nil {
err := HuggingfaceValidationError{
field: "Credential",
reason: "oneof value cannot be a typed-nil",
}
if !all {
return err
}
errors = append(errors, err)
}
if all {
switch v := interface{}(m.GetUnauthenticated()).(type) {
case interface{ ValidateAll() error }:
if err := v.ValidateAll(); err != nil {
errors = append(errors, HuggingfaceValidationError{
field: "Unauthenticated",
reason: "embedded message failed validation",
cause: err,
})
}
case interface{ Validate() error }:
if err := v.Validate(); err != nil {
errors = append(errors, HuggingfaceValidationError{
field: "Unauthenticated",
reason: "embedded message failed validation",
cause: err,
})
}
}
} else if v, ok := interface{}(m.GetUnauthenticated()).(interface{ Validate() error }); ok {
if err := v.Validate(); err != nil {
return HuggingfaceValidationError{
field: "Unauthenticated",
reason: "embedded message failed validation",
cause: err,
}
}
}
default:
_ = v // ensures v is used
}
if len(errors) > 0 {
return HuggingfaceMultiError(errors)
}
return nil
}
// HuggingfaceMultiError is an error wrapping multiple validation errors
// returned by Huggingface.ValidateAll() if the designated constraints aren't met.
type HuggingfaceMultiError []error
// Error returns a concatenation of all the error messages it wraps.
func (m HuggingfaceMultiError) Error() string {
var msgs []string
for _, err := range m {
msgs = append(msgs, err.Error())
}
return strings.Join(msgs, "; ")
}
// AllErrors returns a list of validation violation errors.
func (m HuggingfaceMultiError) AllErrors() []error { return m }
// HuggingfaceValidationError is the validation error returned by
// Huggingface.Validate if the designated constraints aren't met.
type HuggingfaceValidationError struct {
field string
reason string
cause error
key bool
}
// Field function returns field value.
func (e HuggingfaceValidationError) Field() string { return e.field }
// Reason function returns reason value.
func (e HuggingfaceValidationError) Reason() string { return e.reason }
// Cause function returns cause value.
func (e HuggingfaceValidationError) Cause() error { return e.cause }
// Key function returns key value.
func (e HuggingfaceValidationError) Key() bool { return e.key }
// ErrorName returns error name.
func (e HuggingfaceValidationError) ErrorName() string { return "HuggingfaceValidationError" }
// Error satisfies the builtin error interface
func (e HuggingfaceValidationError) Error() string {
cause := ""
if e.cause != nil {
cause = fmt.Sprintf(" | caused by: %v", e.cause)
}
key := ""
if e.key {
key = "key for "
}
return fmt.Sprintf(
"invalid %sHuggingface.%s: %s%s",
key,
e.field,
e.reason,
cause)
}
var _ error = HuggingfaceValidationError{}
var _ interface {
Field() string
Reason() string
Key() bool
Cause() error
ErrorName() string
} = HuggingfaceValidationError{}
// Validate checks the field values on JIRA with the rules defined in the proto
// definition for this message. If any rules are violated, the first error
// encountered is returned, or nil if there are no violations.

View file

@ -0,0 +1,223 @@
package huggingface
import (
"encoding/json"
"errors"
"fmt"
"net/http"
"time"
"github.com/trufflesecurity/trufflehog/v3/pkg/context"
)
// Maps for API and HTML paths
var apiPaths = map[string]string{
DATASET: DatasetsRoute,
MODEL: ModelsAPIRoute,
SPACE: SpacesRoute,
}
var htmlPaths = map[string]string{
DATASET: DatasetsRoute,
MODEL: "",
SPACE: SpacesRoute,
}
type Author struct {
Username string `json:"name"`
}
type Latest struct {
Raw string `json:"raw"`
}
type Data struct {
Latest Latest `json:"latest"`
}
type Event struct {
Type string `json:"type"`
Author Author `json:"author"`
CreatedAt string `json:"createdAt"`
Data Data `json:"data"`
ID string `json:"id"`
}
func (e Event) GetAuthor() string {
return e.Author.Username
}
func (e Event) GetCreatedAt() string {
return e.CreatedAt
}
func (e Event) GetID() string {
return e.ID
}
type RepoData struct {
FullName string `json:"name"`
ResourceType string `json:"type"`
}
type Discussion struct {
ID int `json:"num"`
IsPR bool `json:"isPullRequest"`
CreatedAt string `json:"createdAt"`
Title string `json:"title"`
Events []Event `json:"events"`
Repo RepoData `json:"repo"`
}
func (d Discussion) GetID() string {
return fmt.Sprint(d.ID)
}
func (d Discussion) GetTitle() string {
return d.Title
}
func (d Discussion) GetCreatedAt() string {
return d.CreatedAt
}
func (d Discussion) GetRepo() string {
return d.Repo.FullName
}
// GetDiscussionPath returns the path (ex: "/models/user/repo/discussions/1") for the discussion
func (d Discussion) GetDiscussionPath() string {
basePath := fmt.Sprintf("%s/%s/%s", d.GetRepo(), DiscussionsRoute, d.GetID())
if d.Repo.ResourceType == "model" {
return basePath
}
return fmt.Sprintf("%s/%s", getResourceHTMLPath(d.Repo.ResourceType), basePath)
}
// GetGitPath returns the path (ex: "/models/user/repo.git") for the repo's git directory
func (d Discussion) GetGitPath() string {
basePath := fmt.Sprintf("%s.git", d.GetRepo())
if d.Repo.ResourceType == "model" {
return basePath
}
return fmt.Sprintf("%s/%s", getResourceHTMLPath(d.Repo.ResourceType), basePath)
}
type DiscussionList struct {
Discussions []Discussion `json:"discussions"`
}
type Repo struct {
IsPrivate bool `json:"private"`
Owner string `json:"author"`
RepoID string `json:"id"`
}
type HFClient struct {
BaseURL string
APIKey string
HTTPClient *http.Client
}
// NewClient creates a new API client
func NewHFClient(baseURL, apiKey string, timeout time.Duration) *HFClient {
return &HFClient{
BaseURL: baseURL,
APIKey: apiKey,
HTTPClient: &http.Client{
Timeout: timeout,
},
}
}
// get makes a GET request to the Hugging Face API
// Note: not addressing rate limit, since it seems very permissive. (ex: "If \
// your account suddenly sends 10k requests then youre likely to receive 503")
func (c *HFClient) get(ctx context.Context, url string, target interface{}) error {
ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return fmt.Errorf("failed to create HuggingFace API request: %w", err)
}
req.Header.Set("Authorization", "Bearer "+c.APIKey)
resp, err := c.HTTPClient.Do(req)
if err != nil {
return fmt.Errorf("failed to make request to HuggingFace API: %w", err)
}
if resp.StatusCode == http.StatusUnauthorized {
return errors.New("invalid API key.")
}
if resp.StatusCode == http.StatusForbidden {
return errors.New("access to this repo is restricted and you are not in the authorized list. Visit the repository to ask for access.")
}
defer resp.Body.Close()
return json.NewDecoder(resp.Body).Decode(target)
}
// GetRepo retrieves repo from the Hugging Face API
func (c *HFClient) GetRepo(ctx context.Context, repoName string, resourceType string) (Repo, error) {
var repo Repo
url, err := buildAPIURL(c.BaseURL, resourceType, repoName)
if err != nil {
return repo, err
}
err = c.get(ctx, url, &repo)
return repo, err
}
// ListDiscussions retrieves discussions from the Hugging Face API
func (c *HFClient) ListDiscussions(ctx context.Context, repoInfo repoInfo) (DiscussionList, error) {
var discussions DiscussionList
baseURL, err := buildAPIURL(c.BaseURL, string(repoInfo.resourceType), repoInfo.fullName)
if err != nil {
return discussions, err
}
url := fmt.Sprintf("%s/%s", baseURL, DiscussionsRoute)
err = c.get(ctx, url, &discussions)
return discussions, err
}
func (c *HFClient) GetDiscussionByID(ctx context.Context, repoInfo repoInfo, discussionID string) (Discussion, error) {
var discussion Discussion
baseURL, err := buildAPIURL(c.BaseURL, string(repoInfo.resourceType), repoInfo.fullName)
if err != nil {
return discussion, err
}
url := fmt.Sprintf("%s/%s/%s", baseURL, DiscussionsRoute, discussionID)
err = c.get(ctx, url, &discussion)
return discussion, err
}
// ListReposByAuthor retrieves repos from the Hugging Face API by author (user or org)
// Note: not addressing pagination b/c allow by default 1000 results, which should be enough for 99.99% of cases
func (c *HFClient) ListReposByAuthor(ctx context.Context, resourceType string, author string) ([]Repo, error) {
var repos []Repo
url := fmt.Sprintf("%s/%s/%s?limit=1000&author=%s", c.BaseURL, APIRoute, getResourceAPIPath(resourceType), author)
err := c.get(ctx, url, &repos)
return repos, err
}
// getResourceAPIPath returns the API path for the given resource type
func getResourceAPIPath(resourceType string) string {
return apiPaths[resourceType]
}
// getResourceHTMLPath returns the HTML path for the given resource type
func getResourceHTMLPath(resourceType string) string {
return htmlPaths[resourceType]
}
func buildAPIURL(endpoint string, resourceType string, repoName string) (string, error) {
if endpoint == "" || resourceType == "" || repoName == "" {
return "", errors.New("endpoint, resourceType, and repoName must not be empty")
}
return fmt.Sprintf("%s/%s/%s/%s", endpoint, APIRoute, getResourceAPIPath(resourceType), repoName), nil
}

View file

@ -0,0 +1,680 @@
package huggingface
import (
"fmt"
"os"
"sort"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/go-logr/logr"
"github.com/gobwas/glob"
"golang.org/x/sync/errgroup"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/anypb"
"github.com/trufflesecurity/trufflehog/v3/pkg/cache"
"github.com/trufflesecurity/trufflehog/v3/pkg/cache/memory"
"github.com/trufflesecurity/trufflehog/v3/pkg/common"
"github.com/trufflesecurity/trufflehog/v3/pkg/context"
"github.com/trufflesecurity/trufflehog/v3/pkg/giturl"
"github.com/trufflesecurity/trufflehog/v3/pkg/pb/source_metadatapb"
"github.com/trufflesecurity/trufflehog/v3/pkg/pb/sourcespb"
"github.com/trufflesecurity/trufflehog/v3/pkg/sanitizer"
"github.com/trufflesecurity/trufflehog/v3/pkg/sources"
"github.com/trufflesecurity/trufflehog/v3/pkg/sources/git"
)
const (
SourceType = sourcespb.SourceType_SOURCE_TYPE_HUGGINGFACE
DatasetsRoute = "datasets"
SpacesRoute = "spaces"
ModelsAPIRoute = "models"
DiscussionsRoute = "discussions"
APIRoute = "api"
DATASET = "dataset"
MODEL = "model"
SPACE = "space"
defaultPagination = 100
)
type resourceType string
type Source struct {
name string
huggingfaceToken string
sourceID sources.SourceID
jobID sources.JobID
verify bool
useCustomContentWriter bool
orgsCache cache.Cache[string]
usersCache cache.Cache[string]
models []string
spaces []string
datasets []string
filteredModelsCache *filteredRepoCache
filteredSpacesCache *filteredRepoCache
filteredDatasetsCache *filteredRepoCache
repoInfoCache repoInfoCache
git *git.Git
scanOptions *git.ScanOptions
apiClient *HFClient
log logr.Logger
conn *sourcespb.Huggingface
jobPool *errgroup.Group
resumeInfoMutex sync.Mutex
resumeInfoSlice []string
skipAllModels bool
skipAllSpaces bool
skipAllDatasets bool
includeDiscussions bool
includePrs bool
sources.Progress
sources.CommonSourceUnitUnmarshaller
}
// Ensure the Source satisfies the interfaces at compile time
var _ sources.Source = (*Source)(nil)
var _ sources.SourceUnitUnmarshaller = (*Source)(nil)
// WithCustomContentWriter sets the useCustomContentWriter flag on the source.
func (s *Source) WithCustomContentWriter() { s.useCustomContentWriter = true }
// Type returns the type of source.
// It is used for matching source types in configuration and job input.
func (s *Source) Type() sourcespb.SourceType {
return SourceType
}
func (s *Source) SourceID() sources.SourceID {
return s.sourceID
}
func (s *Source) JobID() sources.JobID {
return s.jobID
}
// filteredRepoCache is a wrapper around cache.Cache that filters out repos
// based on include and exclude globs.
type filteredRepoCache struct {
cache.Cache[string]
include, exclude []glob.Glob
}
func (s *Source) newFilteredRepoCache(c cache.Cache[string], include, exclude []string) *filteredRepoCache {
includeGlobs := make([]glob.Glob, 0, len(include))
excludeGlobs := make([]glob.Glob, 0, len(exclude))
for _, ig := range include {
g, err := glob.Compile(ig)
if err != nil {
s.log.V(1).Info("invalid include glob", "include_value", ig, "err", err)
continue
}
includeGlobs = append(includeGlobs, g)
}
for _, eg := range exclude {
g, err := glob.Compile(eg)
if err != nil {
s.log.V(1).Info("invalid exclude glob", "exclude_value", eg, "err", err)
continue
}
excludeGlobs = append(excludeGlobs, g)
}
return &filteredRepoCache{Cache: c, include: includeGlobs, exclude: excludeGlobs}
}
// Set overrides the cache.Cache Set method to filter out repos based on
// include and exclude globs.
func (c *filteredRepoCache) Set(key, val string) {
if c.ignoreRepo(key) {
return
}
if !c.includeRepo(key) {
return
}
c.Cache.Set(key, val)
}
func (c *filteredRepoCache) ignoreRepo(s string) bool {
for _, g := range c.exclude {
if g.Match(s) {
return true
}
}
return false
}
func (c *filteredRepoCache) includeRepo(s string) bool {
if len(c.include) == 0 {
return true
}
for _, g := range c.include {
if g.Match(s) {
return true
}
}
return false
}
// Init returns an initialized HuggingFace source.
func (s *Source) Init(aCtx context.Context, name string, jobID sources.JobID, sourceID sources.SourceID, verify bool, connection *anypb.Any, concurrency int) error {
err := git.CmdCheck()
if err != nil {
return err
}
s.log = aCtx.Logger()
s.name = name
s.sourceID = sourceID
s.jobID = jobID
s.verify = verify
s.jobPool = &errgroup.Group{}
s.jobPool.SetLimit(concurrency)
var conn sourcespb.Huggingface
err = anypb.UnmarshalTo(connection, &conn, proto.UnmarshalOptions{})
if err != nil {
return fmt.Errorf("error unmarshalling connection: %w", err)
}
s.conn = &conn
s.orgsCache = memory.New[string]()
for _, org := range s.conn.Organizations {
s.orgsCache.Set(org, org)
}
s.usersCache = memory.New[string]()
for _, user := range s.conn.Users {
s.usersCache.Set(user, user)
}
//Verify ignore and include models, spaces, and datasets are valid
// this ensures that calling --org <org> --ignore-model <org/model> contains the proper
// repo format of org/model. Otherwise, we would scan the entire org.
if err := s.validateIgnoreIncludeRepos(); err != nil {
return err
}
s.filteredModelsCache = s.newFilteredRepoCache(memory.New[string](),
append(s.conn.GetModels(), s.conn.GetIncludeModels()...),
s.conn.GetIgnoreModels(),
)
s.filteredSpacesCache = s.newFilteredRepoCache(memory.New[string](),
append(s.conn.GetSpaces(), s.conn.GetIncludeSpaces()...),
s.conn.GetIgnoreSpaces(),
)
s.filteredDatasetsCache = s.newFilteredRepoCache(memory.New[string](),
append(s.conn.GetDatasets(), s.conn.GetIncludeDatasets()...),
s.conn.GetIgnoreDatasets(),
)
s.models = initializeRepos(s.filteredModelsCache, s.conn.Models, fmt.Sprintf("%s/%s.git", s.conn.Endpoint, "%s"))
s.spaces = initializeRepos(s.filteredSpacesCache, s.conn.Spaces, fmt.Sprintf("%s/%s/%s.git", s.conn.Endpoint, SpacesRoute, "%s"))
s.datasets = initializeRepos(s.filteredDatasetsCache, s.conn.Datasets, fmt.Sprintf("%s/%s/%s.git", s.conn.Endpoint, DatasetsRoute, "%s"))
s.repoInfoCache = newRepoInfoCache()
s.includeDiscussions = s.conn.IncludeDiscussions
s.includePrs = s.conn.IncludePrs
cfg := &git.Config{
SourceName: s.name,
JobID: s.jobID,
SourceID: s.sourceID,
SourceType: s.Type(),
Verify: s.verify,
Concurrency: concurrency,
SourceMetadataFunc: func(file, email, commit, timestamp, repository string, line int64) *source_metadatapb.MetaData {
return &source_metadatapb.MetaData{
Data: &source_metadatapb.MetaData_Huggingface{
Huggingface: &source_metadatapb.Huggingface{
Commit: sanitizer.UTF8(commit),
File: sanitizer.UTF8(file),
Email: sanitizer.UTF8(email),
Repository: sanitizer.UTF8(repository),
Link: giturl.GenerateLink(repository, commit, file, line),
Timestamp: sanitizer.UTF8(timestamp),
Line: line,
Visibility: s.visibilityOf(aCtx, repository),
ResourceType: s.getResourceType(aCtx, repository),
},
},
}
},
UseCustomContentWriter: s.useCustomContentWriter,
}
s.git = git.NewGit(cfg)
s.huggingfaceToken = s.conn.GetToken()
s.apiClient = NewHFClient(s.conn.Endpoint, s.huggingfaceToken, 10*time.Second)
s.skipAllModels = s.conn.SkipAllModels
s.skipAllSpaces = s.conn.SkipAllSpaces
s.skipAllDatasets = s.conn.SkipAllDatasets
return nil
}
func (s *Source) validateIgnoreIncludeRepos() error {
if err := verifySlashSeparatedStrings(s.conn.IgnoreModels); err != nil {
return err
}
if err := verifySlashSeparatedStrings(s.conn.IncludeModels); err != nil {
return err
}
if err := verifySlashSeparatedStrings(s.conn.IgnoreSpaces); err != nil {
return err
}
if err := verifySlashSeparatedStrings(s.conn.IncludeSpaces); err != nil {
return err
}
if err := verifySlashSeparatedStrings(s.conn.IgnoreDatasets); err != nil {
return err
}
if err := verifySlashSeparatedStrings(s.conn.IncludeDatasets); err != nil {
return err
}
return nil
}
func verifySlashSeparatedStrings(s []string) error {
for _, str := range s {
if !strings.Contains(str, "/") {
return fmt.Errorf("invalid owner/repo: %s", str)
}
}
return nil
}
func initializeRepos(cache *filteredRepoCache, repos []string, urlPattern string) []string {
returnRepos := make([]string, 0)
for _, repo := range repos {
if !cache.ignoreRepo(repo) {
url := fmt.Sprintf(urlPattern, repo)
cache.Set(repo, url)
returnRepos = append(returnRepos, repo)
}
}
return returnRepos
}
func (s *Source) getResourceType(ctx context.Context, repoURL string) string {
repoInfo, ok := s.repoInfoCache.get(repoURL)
if !ok {
// This should never happen.
err := fmt.Errorf("no repoInfo for URL: %s", repoURL)
ctx.Logger().Error(err, "failed to get repository resource type")
return ""
}
return string(repoInfo.resourceType)
}
func (s *Source) visibilityOf(ctx context.Context, repoURL string) source_metadatapb.Visibility {
repoInfo, ok := s.repoInfoCache.get(repoURL)
if !ok {
// This should never happen.
err := fmt.Errorf("no repoInfo for URL: %s", repoURL)
ctx.Logger().Error(err, "failed to get repository visibility")
return source_metadatapb.Visibility_unknown
}
return repoInfo.visibility
}
// Chunks emits chunks of bytes over a channel.
func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk, targets ...sources.ChunkingTarget) error {
err := s.enumerate(ctx)
if err != nil {
return err
}
return s.scan(ctx, chunksChan)
}
func (s *Source) enumerate(ctx context.Context) error {
s.enumerateAuthors(ctx)
s.models = make([]string, 0, s.filteredModelsCache.Count())
for _, repo := range s.filteredModelsCache.Keys() {
if err := s.cacheRepoInfo(ctx, repo, MODEL, s.filteredModelsCache); err != nil {
continue
}
}
s.spaces = make([]string, 0, s.filteredSpacesCache.Count())
for _, repo := range s.filteredSpacesCache.Keys() {
if err := s.cacheRepoInfo(ctx, repo, SPACE, s.filteredSpacesCache); err != nil {
continue
}
}
s.datasets = make([]string, 0, s.filteredDatasetsCache.Count())
for _, repo := range s.filteredDatasetsCache.Keys() {
if err := s.cacheRepoInfo(ctx, repo, DATASET, s.filteredDatasetsCache); err != nil {
continue
}
}
s.log.Info("Completed enumeration", "num_models", len(s.models), "num_spaces", len(s.spaces), "num_datasets", len(s.datasets))
// We must sort the repos so we can resume later if necessary.
sort.Strings(s.models)
sort.Strings(s.datasets)
sort.Strings(s.spaces)
return nil
}
func (s *Source) cacheRepoInfo(ctx context.Context, repo string, repoType string, repoCache *filteredRepoCache) error {
repoURL, _ := repoCache.Get(repo)
repoCtx := context.WithValue(ctx, repoType, repoURL)
if _, ok := s.repoInfoCache.get(repoURL); !ok {
repoCtx.Logger().V(2).Info("Caching " + repoType + " info")
repo, err := s.apiClient.GetRepo(repoCtx, repo, repoType)
if err != nil {
repoCtx.Logger().Error(err, "failed to fetch "+repoType)
return err
}
// check if repo empty
if repo.RepoID == "" {
repoCtx.Logger().Error(fmt.Errorf("no repo found for repo"), repoURL)
return nil
}
s.repoInfoCache.put(repoURL, repoInfo{
owner: repo.Owner,
name: strings.Split(repo.RepoID, "/")[1],
fullName: repo.RepoID,
visibility: getVisibility(repo.IsPrivate),
resourceType: resourceType(repoType),
})
}
s.updateRepoLists(repoURL, repoType)
return nil
}
func getVisibility(isPrivate bool) source_metadatapb.Visibility {
if isPrivate {
return source_metadatapb.Visibility_private
}
return source_metadatapb.Visibility_public
}
func (s *Source) updateRepoLists(repoURL string, repoType string) {
switch repoType {
case MODEL:
s.models = append(s.models, repoURL)
case SPACE:
s.spaces = append(s.spaces, repoURL)
case DATASET:
s.datasets = append(s.datasets, repoURL)
}
}
func (s *Source) fetchAndCacheRepos(ctx context.Context, resourceType string, org string) error {
var repos []Repo
var err error
var url string
var filteredCache *filteredRepoCache
switch resourceType {
case MODEL:
filteredCache = s.filteredModelsCache
url = fmt.Sprintf("%s/%s.git", s.conn.Endpoint, "%s")
repos, err = s.apiClient.ListReposByAuthor(ctx, MODEL, org)
case SPACE:
filteredCache = s.filteredSpacesCache
url = fmt.Sprintf("%s/%s/%s.git", s.conn.Endpoint, SpacesRoute, "%s")
repos, err = s.apiClient.ListReposByAuthor(ctx, SPACE, org)
case DATASET:
filteredCache = s.filteredDatasetsCache
url = fmt.Sprintf("%s/%s/%s.git", s.conn.Endpoint, DatasetsRoute, "%s")
repos, err = s.apiClient.ListReposByAuthor(ctx, DATASET, org)
}
if err != nil {
return err
}
for _, repo := range repos {
repoURL := fmt.Sprintf(url, repo.RepoID)
filteredCache.Set(repo.RepoID, repoURL)
if err := s.cacheRepoInfo(ctx, repo.RepoID, resourceType, filteredCache); err != nil {
continue
}
}
return nil
}
func (s *Source) enumerateAuthors(ctx context.Context) {
for _, org := range s.orgsCache.Keys() {
orgCtx := context.WithValue(ctx, "organization", org)
if !s.skipAllModels {
if err := s.fetchAndCacheRepos(orgCtx, MODEL, org); err != nil {
orgCtx.Logger().Error(err, "Failed to fetch models for organization")
continue
}
}
if !s.skipAllSpaces {
if err := s.fetchAndCacheRepos(orgCtx, SPACE, org); err != nil {
orgCtx.Logger().Error(err, "Failed to fetch spaces for organization")
continue
}
}
if !s.skipAllDatasets {
if err := s.fetchAndCacheRepos(orgCtx, DATASET, org); err != nil {
orgCtx.Logger().Error(err, "Failed to fetch datasets for organization")
continue
}
}
}
for _, user := range s.usersCache.Keys() {
userCtx := context.WithValue(ctx, "user", user)
if !s.skipAllModels {
if err := s.fetchAndCacheRepos(userCtx, MODEL, user); err != nil {
userCtx.Logger().Error(err, "Failed to fetch models for user")
continue
}
}
if !s.skipAllSpaces {
if err := s.fetchAndCacheRepos(userCtx, SPACE, user); err != nil {
userCtx.Logger().Error(err, "Failed to fetch spaces for user")
continue
}
}
if !s.skipAllDatasets {
if err := s.fetchAndCacheRepos(userCtx, DATASET, user); err != nil {
userCtx.Logger().Error(err, "Failed to fetch datasets for user")
continue
}
}
}
}
func (s *Source) scanRepos(ctx context.Context, chunksChan chan *sources.Chunk, resourceType string) error {
var scannedCount uint64 = 1
repos := s.getReposListByType(resourceType)
s.log.V(2).Info("Found "+resourceType+" to scan", "count", len(repos))
// If there is resume information available, limit this scan to only the repos that still need scanning.
reposToScan, progressIndexOffset := sources.FilterReposToResume(repos, s.GetProgress().EncodedResumeInfo)
repos = reposToScan
scanErrs := sources.NewScanErrors()
if s.scanOptions == nil {
s.scanOptions = &git.ScanOptions{}
}
for i, repoURL := range repos {
i, repoURL := i, repoURL
s.jobPool.Go(func() error {
if common.IsDone(ctx) {
return nil
}
// TODO: set progress complete is being called concurrently with i
s.setProgressCompleteWithRepo(i, progressIndexOffset, repoURL, resourceType, repos)
// Ensure the repo is removed from the resume info after being scanned.
defer func(s *Source, repoURL string) {
s.resumeInfoMutex.Lock()
defer s.resumeInfoMutex.Unlock()
s.resumeInfoSlice = sources.RemoveRepoFromResumeInfo(s.resumeInfoSlice, repoURL)
}(s, repoURL)
// Scan the repository
repoInfo, ok := s.repoInfoCache.get(repoURL)
if !ok {
// This should never happen.
err := fmt.Errorf("no repoInfo for URL: %s", repoURL)
s.log.Error(err, "failed to scan "+resourceType)
return nil
}
repoCtx := context.WithValues(ctx, resourceType, repoURL)
duration, err := s.cloneAndScanRepo(repoCtx, repoURL, repoInfo, chunksChan)
if err != nil {
scanErrs.Add(err)
return nil
}
// Scan discussions and PRs, if enabled.
if s.includeDiscussions || s.includePrs {
if err = s.scanDiscussions(repoCtx, repoInfo, chunksChan); err != nil {
scanErrs.Add(fmt.Errorf("error scanning discussions/PRs in repo %s: %w", repoURL, err))
return nil
}
}
repoCtx.Logger().V(2).Info(fmt.Sprintf("scanned %d/%d "+resourceType+"s", scannedCount, len(s.models)), "duration_seconds", duration)
atomic.AddUint64(&scannedCount, 1)
return nil
})
}
_ = s.jobPool.Wait()
if scanErrs.Count() > 0 {
s.log.V(0).Info("failed to scan some repositories", "error_count", scanErrs.Count(), "errors", scanErrs.String())
}
s.SetProgressComplete(len(repos), len(repos), "Completed HuggingFace "+resourceType+" scan", "")
return nil
}
func (s *Source) getReposListByType(resourceType string) []string {
switch resourceType {
case MODEL:
return s.models
case SPACE:
return s.spaces
case DATASET:
return s.datasets
}
return nil
}
func (s *Source) scan(ctx context.Context, chunksChan chan *sources.Chunk) error {
if err := s.scanRepos(ctx, chunksChan, MODEL); err != nil {
return err
}
if err := s.scanRepos(ctx, chunksChan, SPACE); err != nil {
return err
}
if err := s.scanRepos(ctx, chunksChan, DATASET); err != nil {
return err
}
return nil
}
func (s *Source) cloneAndScanRepo(ctx context.Context, repoURL string, repoInfo repoInfo, chunksChan chan *sources.Chunk) (time.Duration, error) {
ctx.Logger().V(2).Info("attempting to clone %s", repoInfo.resourceType)
path, repo, err := s.cloneRepo(ctx, repoURL)
if err != nil {
return 0, err
}
defer os.RemoveAll(path)
var logger logr.Logger
logger.V(2).Info("scanning %s", repoInfo.resourceType)
start := time.Now()
if err = s.git.ScanRepo(ctx, repo, path, s.scanOptions, sources.ChanReporter{Ch: chunksChan}); err != nil {
return 0, fmt.Errorf("error scanning repo %s: %w", repoURL, err)
}
return time.Since(start), nil
}
// setProgressCompleteWithRepo calls the s.SetProgressComplete after safely setting up the encoded resume info string.
func (s *Source) setProgressCompleteWithRepo(index int, offset int, repoURL string, resourceType string, repos []string) {
s.resumeInfoMutex.Lock()
defer s.resumeInfoMutex.Unlock()
// Add the repoURL to the resume info slice.
s.resumeInfoSlice = append(s.resumeInfoSlice, repoURL)
sort.Strings(s.resumeInfoSlice)
// Make the resume info string from the slice.
encodedResumeInfo := sources.EncodeResumeInfo(s.resumeInfoSlice)
s.SetProgressComplete(index+offset, len(repos)+offset, fmt.Sprintf("%ss: %s", resourceType, repoURL), encodedResumeInfo)
}
func (s *Source) scanDiscussions(ctx context.Context, repoInfo repoInfo, chunksChan chan *sources.Chunk) error {
discussions, err := s.apiClient.ListDiscussions(ctx, repoInfo)
if err != nil {
return err
}
for _, discussion := range discussions.Discussions {
if (discussion.IsPR && s.includePrs) || (!discussion.IsPR && s.includeDiscussions) {
d, err := s.apiClient.GetDiscussionByID(ctx, repoInfo, discussion.GetID())
if err != nil {
return err
}
// Note: there is no discussion "description" or similar to chunk, only comments
if err = s.chunkDiscussionComments(ctx, repoInfo, d, chunksChan); err != nil {
return err
}
}
}
return nil
}
func (s *Source) chunkDiscussionComments(ctx context.Context, repoInfo repoInfo, discussion Discussion, chunksChan chan *sources.Chunk) error {
for _, comment := range discussion.Events {
chunk := &sources.Chunk{
SourceName: s.name,
SourceID: s.SourceID(),
JobID: s.JobID(),
SourceType: s.Type(),
SourceMetadata: &source_metadatapb.MetaData{
Data: &source_metadatapb.MetaData_Huggingface{
Huggingface: &source_metadatapb.Huggingface{
Link: sanitizer.UTF8(fmt.Sprintf("%s/%s#%s", s.conn.Endpoint, discussion.GetDiscussionPath(), comment.GetID())),
Username: sanitizer.UTF8(comment.GetAuthor()),
Repository: sanitizer.UTF8(fmt.Sprintf("%s/%s", s.conn.Endpoint, discussion.GetGitPath())),
Timestamp: sanitizer.UTF8(comment.GetCreatedAt()),
Visibility: repoInfo.visibility,
},
},
},
Data: []byte(comment.Data.Latest.Raw),
Verify: s.verify,
}
select {
case <-ctx.Done():
return ctx.Err()
case chunksChan <- chunk:
}
}
return nil
}

View file

@ -0,0 +1,651 @@
package huggingface
import (
"fmt"
"strconv"
"testing"
"time"
"github.com/trufflesecurity/trufflehog/v3/pkg/context"
"github.com/stretchr/testify/assert"
"gopkg.in/h2non/gock.v1"
)
const (
TEST_TOKEN = "test token"
)
func initTestClient() *HFClient {
return NewHFClient("https://huggingface.co", TEST_TOKEN, 10*time.Second)
}
func TestGetRepo(t *testing.T) {
resourceType := MODEL
repoName := "test-model"
repoOwner := "test-author"
defer gock.Off()
gock.New("https://huggingface.co").
Get("/"+APIRoute+"/"+getResourceAPIPath(resourceType)+"/"+repoName).
MatchHeader("Authorization", "Bearer "+TEST_TOKEN).
Reply(200).
JSON(map[string]interface{}{
"id": repoOwner + "/" + repoName,
"author": repoOwner,
"private": true,
})
client := initTestClient()
model, err := client.GetRepo(context.Background(), repoName, resourceType)
assert.Nil(t, err)
assert.NotNil(t, model)
assert.Equal(t, repoOwner+"/"+repoName, model.RepoID)
assert.Equal(t, repoOwner, model.Owner)
assert.Equal(t, true, model.IsPrivate)
assert.False(t, gock.HasUnmatchedRequest())
assert.True(t, gock.IsDone())
}
func TestGetRepo_NotFound(t *testing.T) {
resourceType := MODEL
repoName := "doesnotexist"
defer gock.Off()
gock.New("https://huggingface.co").
Get("/"+APIRoute+"/"+getResourceAPIPath(resourceType)+"/"+repoName).
MatchHeader("Authorization", "Bearer "+TEST_TOKEN).
Reply(404).
JSON(map[string]interface{}{
"id": "",
"author": "",
"private": false,
})
client := initTestClient()
model, err := client.GetRepo(context.Background(), repoName, resourceType)
assert.Nil(t, err)
assert.NotNil(t, model)
assert.Equal(t, "", model.RepoID)
assert.Equal(t, "", model.Owner)
assert.Equal(t, false, model.IsPrivate)
assert.False(t, gock.HasUnmatchedRequest())
assert.True(t, gock.IsDone())
}
func TestGetModel_Error(t *testing.T) {
resourceType := MODEL
repoName := "doesnotexist"
defer gock.Off()
gock.New("https://huggingface.co").
Get("/"+APIRoute+"/"+getResourceAPIPath(resourceType)+"/"+repoName).
MatchHeader("Authorization", "Bearer "+TEST_TOKEN).
Reply(500)
client := initTestClient()
model, err := client.GetRepo(context.Background(), repoName, resourceType)
assert.NotNil(t, err)
assert.NotNil(t, model)
assert.False(t, gock.HasUnmatchedRequest())
assert.True(t, gock.IsDone())
}
func TestListDiscussions(t *testing.T) {
repoInfo := repoInfo{
fullName: "test-author/test-model",
resourceType: MODEL,
}
jsonBlob := `{
"discussions": [
{
"num": 2,
"author": {
"avatarUrl": "/avatars/test.svg",
"fullname": "TEST",
"name": "test-author",
"type": "user",
"isPro": false,
"isHf": false,
"isMod": false
},
"repo": {
"name": "test-author/test-model",
"type": "model"
},
"title": "new PR",
"status": "open",
"createdAt": "2024-06-18T14:34:21.000Z",
"isPullRequest": true,
"numComments": 2,
"pinned": false
},
{
"num": 1,
"author": {
"avatarUrl": "/avatars/test.svg",
"fullname": "TEST",
"name": "test-author",
"type": "user",
"isPro": false,
"isHf": false,
"isMod": false
},
"repo": {
"name": "test-author/test-model",
"type": "model"
},
"title": "secret in comment",
"status": "closed",
"createdAt": "2024-06-18T14:31:57.000Z",
"isPullRequest": false,
"numComments": 2,
"pinned": false
}
],
"count": 2,
"start": 0,
"numClosedDiscussions": 1
}`
defer gock.Off()
gock.New("https://huggingface.co").
Get("/"+APIRoute+"/"+getResourceAPIPath(string(repoInfo.resourceType))+"/"+repoInfo.fullName+"/"+DiscussionsRoute).
MatchHeader("Authorization", "Bearer "+TEST_TOKEN).
Reply(200).
JSON(jsonBlob)
client := initTestClient()
discussions, err := client.ListDiscussions(context.Background(), repoInfo)
assert.Nil(t, err)
assert.NotNil(t, discussions)
assert.Equal(t, 2, len(discussions.Discussions))
assert.False(t, gock.HasUnmatchedRequest())
assert.True(t, gock.IsDone())
}
func TestListDiscussions_NotFound(t *testing.T) {
repoInfo := repoInfo{
fullName: "test-author/doesnotexist",
resourceType: MODEL,
}
defer gock.Off()
gock.New("https://huggingface.co").
Get("/"+APIRoute+"/"+getResourceAPIPath(string(repoInfo.resourceType))+"/"+repoInfo.fullName+"/"+DiscussionsRoute).
MatchHeader("Authorization", "Bearer "+TEST_TOKEN).
Reply(404).
JSON(map[string]interface{}{
"discussions": []map[string]interface{}{},
})
client := initTestClient()
discussions, err := client.ListDiscussions(context.Background(), repoInfo)
assert.Nil(t, err)
assert.NotNil(t, discussions)
assert.Equal(t, 0, len(discussions.Discussions))
assert.False(t, gock.HasUnmatchedRequest())
assert.True(t, gock.IsDone())
}
func TestListDiscussions_Error(t *testing.T) {
repoInfo := repoInfo{
fullName: "test-author/doesnotexist",
resourceType: MODEL,
}
defer gock.Off()
gock.New("https://huggingface.co").
Get("/"+APIRoute+"/"+getResourceAPIPath(string(repoInfo.resourceType))+"/"+repoInfo.fullName+"/"+DiscussionsRoute).
MatchHeader("Authorization", "Bearer "+TEST_TOKEN).
Reply(500)
client := initTestClient()
discussions, err := client.ListDiscussions(context.Background(), repoInfo)
assert.NotNil(t, err)
assert.NotNil(t, discussions)
assert.False(t, gock.HasUnmatchedRequest())
assert.True(t, gock.IsDone())
}
func TestGetDiscussionByID(t *testing.T) {
repoInfo := repoInfo{
fullName: "test-author/test-model",
resourceType: MODEL,
}
discussionID := "1"
jsonBlob := `{
"author": {
"avatarUrl": "/avatars/test.svg",
"fullname": "TEST",
"name": "test-author",
"type": "user",
"isPro": false,
"isHf": false,
"isMod": false
},
"num": 1,
"repo": {
"name": "test-author/test-model",
"type": "model"
},
"title": "secret in initial",
"status": "open",
"createdAt": "2024-06-18T14:31:46.000Z",
"events": [
{
"id": "525",
"author": {
"avatarUrl": "/avatars/test.svg",
"fullname": "TEST",
"name": "test-author",
"type": "user",
"isPro": false,
"isHf": false,
"isMod": false,
"isOwner": true,
"isOrgMember": false
},
"createdAt": "2024-06-18T14:31:46.000Z",
"type": "comment",
"data": {
"edited": true,
"hidden": false,
"latest": {
"raw": "dd",
"html": "<p>dd</p>\n",
"updatedAt": "2024-06-18T14:33:32.066Z",
"author": {
"avatarUrl": "/avatars/test.svg",
"fullname": "TEST",
"name": "test-author",
"type": "user",
"isPro": false,
"isHf": false,
"isMod": false
}
},
"numEdits": 1,
"editors": ["trufflej"],
"reactions": [],
"identifiedLanguage": {
"language": "en",
"probability": 0.40104949474334717
},
"isReport": false
}
},
{
"id": "526",
"author": {
"avatarUrl": "/avatars/test.svg",
"fullname": "TEST",
"name": "test-author",
"type": "user",
"isPro": false,
"isHf": false,
"isMod": false,
"isOwner": true,
"isOrgMember": false
},
"createdAt": "2024-06-18T14:32:40.000Z",
"type": "status-change",
"data": {
"status": "closed"
}
},
{
"id": "527",
"author": {
"avatarUrl": "/avatars/test.svg",
"fullname": "TEST",
"name": "test-author",
"type": "user",
"isPro": false,
"isHf": false,
"isMod": false,
"isOwner": true,
"isOrgMember": false
},
"createdAt": "2024-06-18T14:33:27.000Z",
"type": "status-change",
"data": {
"status": "open"
}
}
],
"pinned": false,
"locked": false,
"isPullRequest": false,
"isReport": false
}`
defer gock.Off()
gock.New("https://huggingface.co").
Get("/"+APIRoute+"/"+getResourceAPIPath(string(repoInfo.resourceType))+"/"+repoInfo.fullName+"/"+DiscussionsRoute+"/"+discussionID).
MatchHeader("Authorization", "Bearer "+TEST_TOKEN).
Reply(200).
JSON(jsonBlob)
client := initTestClient()
discussion, err := client.GetDiscussionByID(context.Background(), repoInfo, discussionID)
assert.Nil(t, err)
assert.NotNil(t, discussion)
assert.Equal(t, discussionID, strconv.Itoa(discussion.ID))
assert.Equal(t, 3, len(discussion.Events))
assert.Equal(t, false, discussion.IsPR)
assert.Equal(t, "secret in initial", discussion.Title)
assert.Equal(t, repoInfo.fullName, discussion.Repo.FullName)
assert.Equal(t, string(repoInfo.resourceType), discussion.Repo.ResourceType)
assert.False(t, gock.HasUnmatchedRequest())
assert.True(t, gock.IsDone())
}
func TestGetDiscussionByID_NotFound(t *testing.T) {
repoInfo := repoInfo{
fullName: "test-author/test-model",
resourceType: MODEL,
}
discussionID := "doesnotexist"
defer gock.Off()
gock.New("https://huggingface.co").
Get("/"+APIRoute+"/"+getResourceAPIPath(string(repoInfo.resourceType))+"/"+repoInfo.fullName+"/"+DiscussionsRoute+"/"+discussionID).
MatchHeader("Authorization", "Bearer "+TEST_TOKEN).
Reply(404).
JSON(map[string]interface{}{})
client := initTestClient()
discussion, err := client.GetDiscussionByID(context.Background(), repoInfo, discussionID)
assert.Nil(t, err)
assert.NotNil(t, discussion)
assert.Equal(t, 0, len(discussion.Events))
assert.False(t, gock.HasUnmatchedRequest())
assert.True(t, gock.IsDone())
}
func TestGetDiscussionByID_Error(t *testing.T) {
repoInfo := repoInfo{
fullName: "test-author/test-model",
resourceType: MODEL,
}
discussionID := "doesnotexist"
defer gock.Off()
gock.New("https://huggingface.co").
Get("/"+APIRoute+"/"+getResourceAPIPath(string(repoInfo.resourceType))+"/"+repoInfo.fullName+"/"+DiscussionsRoute+"/"+discussionID).
MatchHeader("Authorization", "Bearer "+TEST_TOKEN).
Reply(500)
client := initTestClient()
discussion, err := client.GetDiscussionByID(context.Background(), repoInfo, discussionID)
assert.NotNil(t, err)
assert.NotNil(t, discussion)
assert.Equal(t, "", discussion.Title)
assert.Equal(t, 0, len(discussion.Events))
assert.False(t, gock.HasUnmatchedRequest())
assert.True(t, gock.IsDone())
}
func TestListReposByAuthor(t *testing.T) {
resourceType := MODEL
author := "test-author"
repo := "test-model"
repo2 := "test-model2"
defer gock.Off()
gock.New("https://huggingface.co").
Get(fmt.Sprintf("/%s/%s", APIRoute, getResourceAPIPath(resourceType))).
MatchParam("author", author).
MatchParam("limit", "1000").
MatchHeader("Authorization", "Bearer "+TEST_TOKEN).
Reply(200).
JSON([]map[string]interface{}{
{
"_id": "1",
"id": author + "/" + repo,
"modelId": author + "/" + repo,
"private": true,
},
{
"_id": "2",
"id": author + "/" + repo2,
"modelId": author + "/" + repo2,
"private": false,
},
})
for _, mock := range gock.Pending() {
fmt.Println(mock.Request().URLStruct.String())
}
client := initTestClient()
repos, err := client.ListReposByAuthor(context.Background(), resourceType, author)
assert.Nil(t, err)
assert.NotNil(t, repos)
assert.Equal(t, 2, len(repos))
// count of repos with private flag
countOfPrivateRepos := 0
for _, repo := range repos {
if repo.IsPrivate {
countOfPrivateRepos++
}
}
assert.Equal(t, 1, countOfPrivateRepos)
// there is no author field in JSON, so assert repo.Owner is empty
for _, repo := range repos {
assert.Equal(t, "", repo.Owner)
}
assert.False(t, gock.HasUnmatchedRequest())
assert.True(t, gock.IsDone())
}
func TestListReposByAuthor_NotFound(t *testing.T) {
resourceType := MODEL
author := "authordoesntexist"
defer gock.Off()
gock.New("https://huggingface.co").
Get(fmt.Sprintf("/%s/%s", APIRoute, getResourceAPIPath(resourceType))).
MatchParam("author", author).
MatchParam("limit", "1000").
MatchHeader("Authorization", "Bearer "+TEST_TOKEN).
Reply(404).
JSON([]map[string]interface{}{})
client := initTestClient()
repos, err := client.ListReposByAuthor(context.Background(), resourceType, author)
assert.Nil(t, err)
assert.NotNil(t, repos)
assert.Equal(t, 0, len(repos))
assert.False(t, gock.HasUnmatchedRequest())
assert.True(t, gock.IsDone())
}
func TestListReposByAuthor_Error(t *testing.T) {
resourceType := MODEL
author := "doesnotexist"
defer gock.Off()
gock.New("https://huggingface.co").
Get(fmt.Sprintf("/%s/%s", APIRoute, getResourceAPIPath(resourceType))).
MatchParam("author", author).
MatchParam("limit", "1000").
MatchHeader("Authorization", "Bearer "+TEST_TOKEN).
Reply(500)
client := initTestClient()
repos, err := client.ListReposByAuthor(context.Background(), resourceType, author)
assert.NotNil(t, err)
assert.Nil(t, repos)
assert.False(t, gock.HasUnmatchedRequest())
assert.True(t, gock.IsDone())
}
func TestGetResourceAPIPath(t *testing.T) {
assert.Equal(t, "models", getResourceAPIPath(MODEL))
assert.Equal(t, "datasets", getResourceAPIPath(DATASET))
assert.Equal(t, "spaces", getResourceAPIPath(SPACE))
}
func TestGetResourceHTMLPath(t *testing.T) {
assert.Equal(t, "", getResourceHTMLPath(MODEL))
assert.Equal(t, "datasets", getResourceHTMLPath(DATASET))
assert.Equal(t, "spaces", getResourceHTMLPath(SPACE))
}
func TestBuildAPIURL_ValidInputs(t *testing.T) {
endpoint := "https://huggingface.co"
resourceType := MODEL
repoName := "test-repo"
expectedURL := "https://huggingface.co/api/models/test-repo"
url, err := buildAPIURL(endpoint, resourceType, repoName)
assert.Nil(t, err)
assert.Equal(t, expectedURL, url)
}
func TestBuildAPIURL_EmptyEndpoint(t *testing.T) {
endpoint := ""
resourceType := MODEL
repoName := "test-repo"
url, err := buildAPIURL(endpoint, resourceType, repoName)
assert.NotNil(t, err)
assert.Equal(t, "", url)
assert.Equal(t, "endpoint, resourceType, and repoName must not be empty", err.Error())
}
func TestBuildAPIURL_EmptyResourceType(t *testing.T) {
endpoint := "https://huggingface.co"
resourceType := ""
repoName := "test-repo"
url, err := buildAPIURL(endpoint, resourceType, repoName)
assert.NotNil(t, err)
assert.Equal(t, "", url)
assert.Equal(t, "endpoint, resourceType, and repoName must not be empty", err.Error())
}
func TestBuildAPIURL_EmptyRepoName(t *testing.T) {
endpoint := "https://huggingface.co"
resourceType := "model"
repoName := ""
url, err := buildAPIURL(endpoint, resourceType, repoName)
assert.NotNil(t, err)
assert.Equal(t, "", url)
assert.Equal(t, "endpoint, resourceType, and repoName must not be empty", err.Error())
}
func TestGetDiscussionPath_ModelResource(t *testing.T) {
discussion := Discussion{
Repo: RepoData{
FullName: "test-author/test-model",
ResourceType: "model",
},
ID: 1,
}
expectedPath := "test-author/test-model/discussions/1"
path := discussion.GetDiscussionPath()
assert.Equal(t, expectedPath, path)
}
func TestGetDiscussionPath_DatasetResource(t *testing.T) {
discussion := Discussion{
Repo: RepoData{
FullName: "test-author/test-dataset",
ResourceType: "dataset",
},
ID: 1,
}
expectedPath := "datasets/test-author/test-dataset/discussions/1"
path := discussion.GetDiscussionPath()
assert.Equal(t, expectedPath, path)
}
func TestGetDiscussionPath_SpaceResource(t *testing.T) {
discussion := Discussion{
Repo: RepoData{
FullName: "test-author/test-space",
ResourceType: "space",
},
ID: 1,
}
expectedPath := "spaces/test-author/test-space/discussions/1"
path := discussion.GetDiscussionPath()
assert.Equal(t, expectedPath, path)
}
func TestGetGitPath_ModelResource(t *testing.T) {
discussion := Discussion{
Repo: RepoData{
FullName: "test-author/test-model",
ResourceType: "model",
},
ID: 1,
}
expectedPath := "test-author/test-model.git"
path := discussion.GetGitPath()
assert.Equal(t, expectedPath, path)
}
func TestGetGitPath_DatasetResource(t *testing.T) {
discussion := Discussion{
Repo: RepoData{
FullName: "test-author/test-dataset",
ResourceType: "dataset",
},
ID: 1,
}
expectedPath := "datasets/test-author/test-dataset.git"
path := discussion.GetGitPath()
assert.Equal(t, expectedPath, path)
}
func TestGetGitPath_SpaceResource(t *testing.T) {
discussion := Discussion{
Repo: RepoData{
FullName: "test-author/test-space",
ResourceType: "space",
},
ID: 1,
}
expectedPath := "spaces/test-author/test-space.git"
path := discussion.GetGitPath()
assert.Equal(t, expectedPath, path)
}

View file

@ -0,0 +1,986 @@
package huggingface
import (
"fmt"
"testing"
"github.com/stretchr/testify/assert"
"github.com/trufflesecurity/trufflehog/v3/pkg/context"
"github.com/trufflesecurity/trufflehog/v3/pkg/pb/source_metadatapb"
"github.com/trufflesecurity/trufflehog/v3/pkg/pb/sourcespb"
"google.golang.org/protobuf/types/known/anypb"
"gopkg.in/h2non/gock.v1"
)
func createTestSource(src *sourcespb.Huggingface) (*Source, *anypb.Any) {
s := &Source{}
conn, err := anypb.New(src)
if err != nil {
panic(err)
}
return s, conn
}
// test include exclude ignore/include orgs, users
func TestInit(t *testing.T) {
s, conn := createTestSource(&sourcespb.Huggingface{
Endpoint: "https://huggingface.co",
Credential: &sourcespb.Huggingface_Token{
Token: "super secret token",
},
Models: []string{"user/model1", "user/model2", "user/ignorethismodel"},
IgnoreModels: []string{"user/ignorethismodel"},
Spaces: []string{"user/space1", "user/space2", "user/ignorethisspace"},
IgnoreSpaces: []string{"user/ignorethisspace"},
Datasets: []string{"user/dataset1", "user/dataset2", "user/ignorethisdataset"},
IgnoreDatasets: []string{"user/ignorethisdataset"},
Organizations: []string{"org1", "org2"},
Users: []string{"user1", "user2"},
})
err := s.Init(context.Background(), "test - huggingface", 0, 1337, false, conn, 1)
assert.Nil(t, err)
assert.ElementsMatch(t, []string{"user/model1", "user/model2"}, s.models)
for _, model := range s.models {
modelURL, _ := s.filteredModelsCache.Get(model)
assert.Equal(t, modelURL, s.conn.Endpoint+"/"+model+".git")
}
assert.ElementsMatch(t, []string{"user/space1", "user/space2"}, s.spaces)
for _, space := range s.spaces {
spaceURL, _ := s.filteredSpacesCache.Get(space)
assert.Equal(t, spaceURL, s.conn.Endpoint+"/"+getResourceHTMLPath(SPACE)+"/"+space+".git")
}
assert.ElementsMatch(t, []string{"user/dataset1", "user/dataset2"}, s.datasets)
for _, dataset := range s.datasets {
datasetURL, _ := s.filteredDatasetsCache.Get(dataset)
assert.Equal(t, datasetURL, s.conn.Endpoint+"/"+getResourceHTMLPath(DATASET)+"/"+dataset+".git")
}
assert.ElementsMatch(t, s.conn.Organizations, s.orgsCache.Keys())
assert.ElementsMatch(t, s.conn.Users, s.usersCache.Keys())
}
func TestGetResourceType(t *testing.T) {
repo := "author/model1"
s, conn := createTestSource(&sourcespb.Huggingface{
Endpoint: "https://huggingface.co",
Credential: &sourcespb.Huggingface_Token{
Token: "super secret token",
},
Models: []string{repo},
})
err := s.Init(context.Background(), "test - huggingface", 0, 1337, false, conn, 1)
assert.Nil(t, err)
defer gock.Off()
// mock the request to the huggingface api
gock.New("https://huggingface.co").
Get("/api/models/author/model1").
Reply(200).
JSON(map[string]interface{}{
"id": "author/model1",
"author": "author",
"private": true,
})
err = s.enumerate(context.Background())
assert.Nil(t, err)
assert.Equal(t, MODEL, s.getResourceType(context.Background(), (s.conn.Endpoint+"/"+repo+".git")))
}
func TestVisibilityOf(t *testing.T) {
repo := "author/model1"
s, conn := createTestSource(&sourcespb.Huggingface{
Endpoint: "https://huggingface.co",
Credential: &sourcespb.Huggingface_Token{
Token: "super secret token",
},
Models: []string{repo},
})
err := s.Init(context.Background(), "test - huggingface", 0, 1337, false, conn, 1)
assert.Nil(t, err)
defer gock.Off()
// mock the request to the huggingface api
gock.New("https://huggingface.co").
Get("/api/models/author/model1").
Reply(200).
JSON(map[string]interface{}{
"id": "author/model1",
"author": "author",
"private": true,
})
err = s.enumerate(context.Background())
assert.Nil(t, err)
assert.Equal(t, source_metadatapb.Visibility(1), s.visibilityOf(context.Background(), (s.conn.Endpoint+"/"+repo+".git")))
}
func TestEnumerate(t *testing.T) {
s, conn := createTestSource(&sourcespb.Huggingface{
Endpoint: "https://huggingface.co",
Credential: &sourcespb.Huggingface_Token{
Token: "super secret token",
},
Models: []string{"author/model1"},
Datasets: []string{"author/dataset1"},
Spaces: []string{"author/space1"},
})
err := s.Init(context.Background(), "test - huggingface", 0, 1337, false, conn, 1)
assert.Nil(t, err)
defer gock.Off()
// mock the request to the huggingface api
gock.New("https://huggingface.co").
Get("/api/models/author/model1").
Reply(200).
JSON(map[string]interface{}{
"id": "author/model1",
"author": "author",
"private": true,
})
gock.New("https://huggingface.co").
Get("/api/datasets/author/dataset1").
Reply(200).
JSON(map[string]interface{}{
"id": "author/dataset1",
"author": "author",
"private": false,
})
gock.New("https://huggingface.co").
Get("/api/spaces/author/space1").
Reply(200).
JSON(map[string]interface{}{
"id": "author/space1",
"author": "author",
"private": false,
})
err = s.enumerate(context.Background())
assert.Nil(t, err)
modelGitURL := "https://huggingface.co/author/model1.git"
datasetGitURL := "https://huggingface.co/datasets/author/dataset1.git"
spaceGitURL := "https://huggingface.co/spaces/author/space1.git"
assert.Equal(t, []string{modelGitURL}, s.models)
assert.Equal(t, []string{datasetGitURL}, s.datasets)
assert.Equal(t, []string{spaceGitURL}, s.spaces)
r, _ := s.repoInfoCache.get(modelGitURL)
assert.Equal(t, r, repoInfo{
visibility: source_metadatapb.Visibility_private,
resourceType: MODEL,
owner: "author",
name: "model1",
fullName: "author/model1",
})
r, _ = s.repoInfoCache.get(datasetGitURL)
assert.Equal(t, r, repoInfo{
visibility: source_metadatapb.Visibility_public,
resourceType: DATASET,
owner: "author",
name: "dataset1",
fullName: "author/dataset1",
})
r, _ = s.repoInfoCache.get(spaceGitURL)
assert.Equal(t, r, repoInfo{
visibility: source_metadatapb.Visibility_public,
resourceType: SPACE,
owner: "author",
name: "space1",
fullName: "author/space1",
})
}
func TestUpdateRepoList(t *testing.T) {
s, conn := createTestSource(&sourcespb.Huggingface{
Endpoint: "https://huggingface.co",
Credential: &sourcespb.Huggingface_Token{
Token: "super secret token",
},
})
err := s.Init(context.Background(), "test - huggingface", 0, 1337, false, conn, 1)
assert.Nil(t, err)
modelGitURL := "https://huggingface.co/author/model1.git"
datasetGitURL := "https://huggingface.co/datasets/author/dataset1.git"
spaceGitURL := "https://huggingface.co/spaces/author/space1.git"
s.updateRepoLists(modelGitURL, MODEL)
s.updateRepoLists(datasetGitURL, DATASET)
s.updateRepoLists(spaceGitURL, SPACE)
assert.Equal(t, []string{modelGitURL}, s.models)
assert.Equal(t, []string{datasetGitURL}, s.datasets)
assert.Equal(t, []string{spaceGitURL}, s.spaces)
}
func TestGetReposListByType(t *testing.T) {
s, conn := createTestSource(&sourcespb.Huggingface{
Endpoint: "https://huggingface.co",
Credential: &sourcespb.Huggingface_Token{
Token: "super secret token",
},
Models: []string{"author/model1", "author/model2"},
Datasets: []string{"author/dataset1", "author/dataset2"},
Spaces: []string{"author/space1", "author/space2"},
})
err := s.Init(context.Background(), "test - huggingface", 0, 1337, false, conn, 1)
assert.Nil(t, err)
assert.Equal(t, s.getReposListByType(MODEL), s.models)
assert.Equal(t, s.getReposListByType(DATASET), s.datasets)
assert.Equal(t, s.getReposListByType(SPACE), s.spaces)
}
func TestGetVisibility(t *testing.T) {
assert.Equal(t, source_metadatapb.Visibility(1), getVisibility(true))
assert.Equal(t, source_metadatapb.Visibility(0), getVisibility(false))
}
func TestEnumerateAuthorsOrg(t *testing.T) {
s, conn := createTestSource(&sourcespb.Huggingface{
Endpoint: "https://huggingface.co",
Credential: &sourcespb.Huggingface_Token{
Token: "super secret token",
},
Organizations: []string{"org"},
})
err := s.Init(context.Background(), "test - huggingface", 0, 1337, false, conn, 1)
assert.Nil(t, err)
defer gock.Off()
// mock the request to the huggingface api
gock.New("https://huggingface.co").
Get(fmt.Sprintf("/%s/%s", APIRoute, getResourceAPIPath(MODEL))).
MatchParam("author", "org").
MatchParam("limit", "1000").
Reply(200).
JSON([]map[string]interface{}{
{
"_id": "1",
"id": "org/model",
"modelId": "org/model",
"private": true,
},
})
gock.New("https://huggingface.co").
Get("/api/models/org/model").
Reply(200).
JSON(map[string]interface{}{
"id": "org/model",
"author": "org",
"private": true,
})
gock.New("https://huggingface.co").
Get(fmt.Sprintf("/%s/%s", APIRoute, getResourceAPIPath(DATASET))).
MatchParam("author", "org").
MatchParam("limit", "1000").
Reply(200).
JSON([]map[string]interface{}{
{
"_id": "3",
"id": "org/dataset",
"modelId": "org/dataset",
"private": true,
},
})
gock.New("https://huggingface.co").
Get("/api/datasets/org/dataset").
Reply(200).
JSON(map[string]interface{}{
"id": "org/dataset",
"author": "org",
"private": true,
})
gock.New("https://huggingface.co").
Get(fmt.Sprintf("/%s/%s", APIRoute, getResourceAPIPath(SPACE))).
MatchParam("author", "org").
MatchParam("limit", "1000").
Reply(200).
JSON([]map[string]interface{}{
{
"_id": "5",
"id": "org/space",
"modelId": "org/space",
"private": true,
},
})
gock.New("https://huggingface.co").
Get("/api/spaces/org/space").
Reply(200).
JSON(map[string]interface{}{
"id": "org/space",
"author": "org",
"private": true,
})
err = s.enumerate(context.Background())
assert.Nil(t, err)
modelGitURL := "https://huggingface.co/org/model.git"
datasetGitURL := "https://huggingface.co/datasets/org/dataset.git"
spaceGitURL := "https://huggingface.co/spaces/org/space.git"
assert.Equal(t, []string{modelGitURL}, s.models)
assert.Equal(t, []string{datasetGitURL}, s.datasets)
assert.Equal(t, []string{spaceGitURL}, s.spaces)
r, _ := s.repoInfoCache.get(modelGitURL)
assert.Equal(t, r, repoInfo{
visibility: source_metadatapb.Visibility_private,
resourceType: MODEL,
owner: "org",
name: "model",
fullName: "org/model",
})
r, _ = s.repoInfoCache.get(datasetGitURL)
assert.Equal(t, r, repoInfo{
visibility: source_metadatapb.Visibility_private,
resourceType: DATASET,
owner: "org",
name: "dataset",
fullName: "org/dataset",
})
r, _ = s.repoInfoCache.get(spaceGitURL)
assert.Equal(t, r, repoInfo{
visibility: source_metadatapb.Visibility_private,
resourceType: SPACE,
owner: "org",
name: "space",
fullName: "org/space",
})
}
func TestEnumerateAuthorsOrgSkipAll(t *testing.T) {
s, conn := createTestSource(&sourcespb.Huggingface{
Endpoint: "https://huggingface.co",
Credential: &sourcespb.Huggingface_Token{
Token: "super secret token",
},
Organizations: []string{"org"},
SkipAllModels: true,
SkipAllDatasets: true,
SkipAllSpaces: true,
})
err := s.Init(context.Background(), "test - huggingface", 0, 1337, false, conn, 1)
assert.Nil(t, err)
defer gock.Off()
// mock the request to the huggingface api
gock.New("https://huggingface.co").
Get(fmt.Sprintf("/%s/%s", APIRoute, getResourceAPIPath(MODEL))).
MatchParam("author", "org").
MatchParam("limit", "1000").
Reply(200).
JSON([]map[string]interface{}{
{
"_id": "1",
"id": "org/model",
"modelId": "org/model",
"private": true,
},
})
gock.New("https://huggingface.co").
Get("/api/models/org/model").
Reply(200).
JSON(map[string]interface{}{
"id": "org/model",
"author": "org",
"private": true,
})
gock.New("https://huggingface.co").
Get(fmt.Sprintf("/%s/%s", APIRoute, getResourceAPIPath(DATASET))).
MatchParam("author", "org").
MatchParam("limit", "1000").
Reply(200).
JSON([]map[string]interface{}{
{
"_id": "3",
"id": "org/dataset",
"modelId": "org/dataset",
"private": true,
},
})
gock.New("https://huggingface.co").
Get("/api/datasets/org/dataset").
Reply(200).
JSON(map[string]interface{}{
"id": "org/dataset",
"author": "org",
"private": true,
})
gock.New("https://huggingface.co").
Get(fmt.Sprintf("/%s/%s", APIRoute, getResourceAPIPath(SPACE))).
MatchParam("author", "org").
MatchParam("limit", "1000").
Reply(200).
JSON([]map[string]interface{}{
{
"_id": "5",
"id": "org/space",
"modelId": "org/space",
"private": true,
},
})
gock.New("https://huggingface.co").
Get("/api/spaces/org/space").
Reply(200).
JSON(map[string]interface{}{
"id": "org/space",
"author": "org",
"private": true,
})
err = s.enumerate(context.Background())
assert.Nil(t, err)
modelGitURL := "https://huggingface.co/org/model.git"
datasetGitURL := "https://huggingface.co/datasets/org/dataset.git"
spaceGitURL := "https://huggingface.co/spaces/org/space.git"
assert.Equal(t, []string{}, s.models)
assert.Equal(t, []string{}, s.datasets)
assert.Equal(t, []string{}, s.spaces)
r, _ := s.repoInfoCache.get(modelGitURL)
assert.Equal(t, r, repoInfo{})
r, _ = s.repoInfoCache.get(datasetGitURL)
assert.Equal(t, r, repoInfo{})
r, _ = s.repoInfoCache.get(spaceGitURL)
assert.Equal(t, r, repoInfo{})
}
func TestEnumerateAuthorsOrgIgnores(t *testing.T) {
s, conn := createTestSource(&sourcespb.Huggingface{
Endpoint: "https://huggingface.co",
Credential: &sourcespb.Huggingface_Token{
Token: "super secret token",
},
Organizations: []string{"org"},
IgnoreModels: []string{"org/model"},
IgnoreDatasets: []string{"org/dataset"},
IgnoreSpaces: []string{"org/space"},
})
err := s.Init(context.Background(), "test - huggingface", 0, 1337, false, conn, 1)
assert.Nil(t, err)
defer gock.Off()
// mock the request to the huggingface api
gock.New("https://huggingface.co").
Get(fmt.Sprintf("/%s/%s", APIRoute, getResourceAPIPath(MODEL))).
MatchParam("author", "org").
MatchParam("limit", "1000").
Reply(200).
JSON([]map[string]interface{}{
{
"_id": "1",
"id": "org/model",
"modelId": "org/model",
"private": true,
},
})
gock.New("https://huggingface.co").
Get("/api/models/org/model").
Reply(200).
JSON(map[string]interface{}{
"id": "org/model",
"author": "org",
"private": true,
})
gock.New("https://huggingface.co").
Get(fmt.Sprintf("/%s/%s", APIRoute, getResourceAPIPath(DATASET))).
MatchParam("author", "org").
MatchParam("limit", "1000").
Reply(200).
JSON([]map[string]interface{}{
{
"_id": "3",
"id": "org/dataset",
"modelId": "org/dataset",
"private": true,
},
})
gock.New("https://huggingface.co").
Get("/api/datasets/org/dataset").
Reply(200).
JSON(map[string]interface{}{
"id": "org/dataset",
"author": "org",
"private": true,
})
gock.New("https://huggingface.co").
Get(fmt.Sprintf("/%s/%s", APIRoute, getResourceAPIPath(SPACE))).
MatchParam("author", "org").
MatchParam("limit", "1000").
Reply(200).
JSON([]map[string]interface{}{
{
"_id": "5",
"id": "org/space",
"modelId": "org/space",
"private": true,
},
})
gock.New("https://huggingface.co").
Get("/api/spaces/org/space").
Reply(200).
JSON(map[string]interface{}{
"id": "org/space",
"author": "org",
"private": true,
})
err = s.enumerate(context.Background())
assert.Nil(t, err)
modelGitURL := "https://huggingface.co/org/model.git"
datasetGitURL := "https://huggingface.co/datasets/org/dataset.git"
spaceGitURL := "https://huggingface.co/spaces/org/space.git"
assert.Equal(t, []string{}, s.models)
assert.Equal(t, []string{}, s.datasets)
assert.Equal(t, []string{}, s.spaces)
r, _ := s.repoInfoCache.get(modelGitURL)
assert.Equal(t, r, repoInfo{})
r, _ = s.repoInfoCache.get(datasetGitURL)
assert.Equal(t, r, repoInfo{})
r, _ = s.repoInfoCache.get(spaceGitURL)
assert.Equal(t, r, repoInfo{})
}
func TestEnumerateAuthorsUser(t *testing.T) {
s, conn := createTestSource(&sourcespb.Huggingface{
Endpoint: "https://huggingface.co",
Credential: &sourcespb.Huggingface_Token{
Token: "super secret token",
},
Users: []string{"user"},
})
err := s.Init(context.Background(), "test - huggingface", 0, 1337, false, conn, 1)
assert.Nil(t, err)
defer gock.Off()
gock.New("https://huggingface.co").
Get(fmt.Sprintf("/%s/%s", APIRoute, getResourceAPIPath(MODEL))).
MatchParam("author", "user").
MatchParam("limit", "1000").
Reply(200).
JSON([]map[string]interface{}{
{
"_id": "2",
"id": "user/model",
"modelId": "user/model",
"private": false,
},
})
gock.New("https://huggingface.co").
Get("/api/models/user/model").
Reply(200).
JSON(map[string]interface{}{
"id": "user/model",
"author": "user",
"private": false,
})
gock.New("https://huggingface.co").
Get(fmt.Sprintf("/%s/%s", APIRoute, getResourceAPIPath(DATASET))).
MatchParam("author", "user").
MatchParam("limit", "1000").
Reply(200).
JSON([]map[string]interface{}{
{
"_id": "4",
"id": "user/dataset",
"modelId": "user/dataset",
"private": false,
},
})
gock.New("https://huggingface.co").
Get("/api/datasets/user/dataset").
Reply(200).
JSON(map[string]interface{}{
"id": "user/dataset",
"author": "user",
"private": false,
})
gock.New("https://huggingface.co").
Get(fmt.Sprintf("/%s/%s", APIRoute, getResourceAPIPath(SPACE))).
MatchParam("author", "user").
MatchParam("limit", "1000").
Reply(200).
JSON([]map[string]interface{}{
{
"_id": "6",
"id": "user/space",
"modelId": "user/space",
"private": false,
},
})
gock.New("https://huggingface.co").
Get("/api/spaces/user/space").
Reply(200).
JSON(map[string]interface{}{
"id": "user/space",
"author": "user",
"private": false,
})
err = s.enumerate(context.Background())
assert.Nil(t, err)
modelGitURL := "https://huggingface.co/user/model.git"
datasetGitURL := "https://huggingface.co/datasets/user/dataset.git"
spaceGitURL := "https://huggingface.co/spaces/user/space.git"
assert.Equal(t, []string{modelGitURL}, s.models)
assert.Equal(t, []string{datasetGitURL}, s.datasets)
assert.Equal(t, []string{spaceGitURL}, s.spaces)
r, _ := s.repoInfoCache.get(modelGitURL)
assert.Equal(t, r, repoInfo{
visibility: source_metadatapb.Visibility_public,
resourceType: MODEL,
owner: "user",
name: "model",
fullName: "user/model",
})
r, _ = s.repoInfoCache.get(datasetGitURL)
assert.Equal(t, r, repoInfo{
visibility: source_metadatapb.Visibility_public,
resourceType: DATASET,
owner: "user",
name: "dataset",
fullName: "user/dataset",
})
r, _ = s.repoInfoCache.get(spaceGitURL)
assert.Equal(t, r, repoInfo{
visibility: source_metadatapb.Visibility_public,
resourceType: SPACE,
owner: "user",
name: "space",
fullName: "user/space",
})
}
func TestEnumerateAuthorsUserSkipAll(t *testing.T) {
s, conn := createTestSource(&sourcespb.Huggingface{
Endpoint: "https://huggingface.co",
Credential: &sourcespb.Huggingface_Token{
Token: "super secret token",
},
Users: []string{"user"},
SkipAllModels: true,
SkipAllDatasets: true,
SkipAllSpaces: true,
})
err := s.Init(context.Background(), "test - huggingface", 0, 1337, false, conn, 1)
assert.Nil(t, err)
defer gock.Off()
gock.New("https://huggingface.co").
Get(fmt.Sprintf("/%s/%s", APIRoute, getResourceAPIPath(MODEL))).
MatchParam("author", "user").
MatchParam("limit", "1000").
Reply(200).
JSON([]map[string]interface{}{
{
"_id": "2",
"id": "user/model",
"modelId": "user/model",
"private": false,
},
})
gock.New("https://huggingface.co").
Get("/api/models/user/model").
Reply(200).
JSON(map[string]interface{}{
"id": "user/model",
"author": "user",
"private": false,
})
gock.New("https://huggingface.co").
Get(fmt.Sprintf("/%s/%s", APIRoute, getResourceAPIPath(DATASET))).
MatchParam("author", "user").
MatchParam("limit", "1000").
Reply(200).
JSON([]map[string]interface{}{
{
"_id": "4",
"id": "user/dataset",
"modelId": "user/dataset",
"private": false,
},
})
gock.New("https://huggingface.co").
Get("/api/datasets/user/dataset").
Reply(200).
JSON(map[string]interface{}{
"id": "user/dataset",
"author": "user",
"private": false,
})
gock.New("https://huggingface.co").
Get(fmt.Sprintf("/%s/%s", APIRoute, getResourceAPIPath(SPACE))).
MatchParam("author", "user").
MatchParam("limit", "1000").
Reply(200).
JSON([]map[string]interface{}{
{
"_id": "6",
"id": "user/space",
"modelId": "user/space",
"private": false,
},
})
gock.New("https://huggingface.co").
Get("/api/spaces/user/space").
Reply(200).
JSON(map[string]interface{}{
"id": "user/space",
"author": "user",
"private": false,
})
err = s.enumerate(context.Background())
assert.Nil(t, err)
modelGitURL := "https://huggingface.co/user/model.git"
datasetGitURL := "https://huggingface.co/datasets/user/dataset.git"
spaceGitURL := "https://huggingface.co/spaces/user/space.git"
assert.Equal(t, []string{}, s.models)
assert.Equal(t, []string{}, s.datasets)
assert.Equal(t, []string{}, s.spaces)
r, _ := s.repoInfoCache.get(modelGitURL)
assert.Equal(t, r, repoInfo{})
r, _ = s.repoInfoCache.get(datasetGitURL)
assert.Equal(t, r, repoInfo{})
r, _ = s.repoInfoCache.get(spaceGitURL)
assert.Equal(t, r, repoInfo{})
}
func TestEnumerateAuthorsUserIgnores(t *testing.T) {
s, conn := createTestSource(&sourcespb.Huggingface{
Endpoint: "https://huggingface.co",
Credential: &sourcespb.Huggingface_Token{
Token: "super secret token",
},
Users: []string{"user"},
IgnoreModels: []string{"user/model"},
IgnoreDatasets: []string{"user/dataset"},
IgnoreSpaces: []string{"user/space"},
})
err := s.Init(context.Background(), "test - huggingface", 0, 1337, false, conn, 1)
assert.Nil(t, err)
defer gock.Off()
gock.New("https://huggingface.co").
Get(fmt.Sprintf("/%s/%s", APIRoute, getResourceAPIPath(MODEL))).
MatchParam("author", "user").
MatchParam("limit", "1000").
Reply(200).
JSON([]map[string]interface{}{
{
"_id": "2",
"id": "user/model",
"modelId": "user/model",
"private": false,
},
})
gock.New("https://huggingface.co").
Get("/api/models/user/model").
Reply(200).
JSON(map[string]interface{}{
"id": "user/model",
"author": "user",
"private": false,
})
gock.New("https://huggingface.co").
Get(fmt.Sprintf("/%s/%s", APIRoute, getResourceAPIPath(DATASET))).
MatchParam("author", "user").
MatchParam("limit", "1000").
Reply(200).
JSON([]map[string]interface{}{
{
"_id": "4",
"id": "user/dataset",
"modelId": "user/dataset",
"private": false,
},
})
gock.New("https://huggingface.co").
Get("/api/datasets/user/dataset").
Reply(200).
JSON(map[string]interface{}{
"id": "user/dataset",
"author": "user",
"private": false,
})
gock.New("https://huggingface.co").
Get(fmt.Sprintf("/%s/%s", APIRoute, getResourceAPIPath(SPACE))).
MatchParam("author", "user").
MatchParam("limit", "1000").
Reply(200).
JSON([]map[string]interface{}{
{
"_id": "6",
"id": "user/space",
"modelId": "user/space",
"private": false,
},
})
gock.New("https://huggingface.co").
Get("/api/spaces/user/space").
Reply(200).
JSON(map[string]interface{}{
"id": "user/space",
"author": "user",
"private": false,
})
err = s.enumerate(context.Background())
assert.Nil(t, err)
modelGitURL := "https://huggingface.co/user/model.git"
datasetGitURL := "https://huggingface.co/datasets/user/dataset.git"
spaceGitURL := "https://huggingface.co/spaces/user/space.git"
assert.Equal(t, []string{}, s.models)
assert.Equal(t, []string{}, s.datasets)
assert.Equal(t, []string{}, s.spaces)
r, _ := s.repoInfoCache.get(modelGitURL)
assert.Equal(t, r, repoInfo{})
r, _ = s.repoInfoCache.get(datasetGitURL)
assert.Equal(t, r, repoInfo{})
r, _ = s.repoInfoCache.get(spaceGitURL)
assert.Equal(t, r, repoInfo{})
}
func TestVerifySlashSeparatedStrings(t *testing.T) {
assert.Error(t, verifySlashSeparatedStrings([]string{"orgmodel"}))
assert.NoError(t, verifySlashSeparatedStrings([]string{"org/model"}))
assert.Error(t, verifySlashSeparatedStrings([]string{"org/model", "orgmodel2"}))
assert.NoError(t, verifySlashSeparatedStrings([]string{"org/model", "org/model2"}))
}
func TestValidateIgnoreIncludeReposRepos(t *testing.T) {
s, conn := createTestSource(&sourcespb.Huggingface{
Endpoint: "https://huggingface.co",
Credential: &sourcespb.Huggingface_Token{
Token: "super secret token",
},
Organizations: []string{"org"},
IgnoreModels: []string{"orgmodel1"},
IgnoreDatasets: []string{"org/dataset1"},
IgnoreSpaces: []string{"org/space1"},
})
err := s.Init(context.Background(), "test - huggingface", 0, 1337, false, conn, 1)
assert.NotNil(t, err)
}
func TestValidateIgnoreIncludeReposDatasets(t *testing.T) {
s, conn := createTestSource(&sourcespb.Huggingface{
Endpoint: "https://huggingface.co",
Credential: &sourcespb.Huggingface_Token{
Token: "super secret token",
},
Organizations: []string{"org"},
IgnoreModels: []string{"org/model1"},
IgnoreDatasets: []string{"orgdataset1"},
IgnoreSpaces: []string{"org/space1"},
})
err := s.Init(context.Background(), "test - huggingface", 0, 1337, false, conn, 1)
assert.NotNil(t, err)
}
func TestValidateIgnoreIncludeReposSpaces(t *testing.T) {
s, conn := createTestSource(&sourcespb.Huggingface{
Endpoint: "https://huggingface.co",
Credential: &sourcespb.Huggingface_Token{
Token: "super secret token",
},
Organizations: []string{"org"},
IgnoreModels: []string{"org/model1"},
IgnoreDatasets: []string{"org/dataset1"},
IgnoreSpaces: []string{"orgspace1"},
})
err := s.Init(context.Background(), "test - huggingface", 0, 1337, false, conn, 1)
assert.NotNil(t, err)
}
// repeat this with all skip flags, and then include/ignore flags

View file

@ -0,0 +1,72 @@
package huggingface
import (
"fmt"
"sync"
gogit "github.com/go-git/go-git/v5"
"github.com/trufflesecurity/trufflehog/v3/pkg/context"
"github.com/trufflesecurity/trufflehog/v3/pkg/pb/source_metadatapb"
"github.com/trufflesecurity/trufflehog/v3/pkg/pb/sourcespb"
"github.com/trufflesecurity/trufflehog/v3/pkg/sources/git"
)
type repoInfoCache struct {
mu sync.RWMutex
cache map[string]repoInfo
}
func newRepoInfoCache() repoInfoCache {
return repoInfoCache{
cache: make(map[string]repoInfo),
}
}
func (r *repoInfoCache) put(repoURL string, info repoInfo) {
r.mu.Lock()
defer r.mu.Unlock()
r.cache[repoURL] = info
}
func (r *repoInfoCache) get(repoURL string) (repoInfo, bool) {
r.mu.RLock()
defer r.mu.RUnlock()
info, ok := r.cache[repoURL]
return info, ok
}
type repoInfo struct {
owner string
name string
fullName string
visibility source_metadatapb.Visibility
resourceType resourceType
}
func (s *Source) cloneRepo(
ctx context.Context,
repoURL string,
) (string, *gogit.Repository, error) {
var (
path string
repo *gogit.Repository
err error
)
switch s.conn.GetCredential().(type) {
case *sourcespb.Huggingface_Unauthenticated:
path, repo, err = git.CloneRepoUsingUnauthenticated(ctx, repoURL)
if err != nil {
return "", nil, err
}
case *sourcespb.Huggingface_Token:
path, repo, err = git.CloneRepoUsingToken(ctx, s.huggingfaceToken, repoURL, "")
if err != nil {
return "", nil, err
}
default:
return "", nil, fmt.Errorf("unhandled credential type for repo %s: %T", repoURL, s.conn.GetCredential())
}
return path, repo, nil
}

View file

@ -136,6 +136,19 @@ message GCS {
string content_type = 8;
}
message Huggingface {
string link = 1;
string username = 2;
string repository = 3;
string commit = 4;
string email = 5;
string file = 6;
string timestamp = 7;
int64 line = 8;
Visibility visibility = 9;
string resource_type = 10;
}
message Jira {
string issue = 1;
string author = 2;
@ -351,5 +364,6 @@ message MetaData {
Postman postman = 29;
Webhook webhook = 30;
Elasticsearch elasticsearch = 31;
Huggingface huggingface = 32;
}
}

View file

@ -48,6 +48,7 @@ enum SourceType {
SOURCE_TYPE_POSTMAN = 33;
SOURCE_TYPE_WEBHOOK = 34;
SOURCE_TYPE_ELASTICSEARCH = 35;
SOURCE_TYPE_HUGGINGFACE = 36;
}
message LocalSource {
@ -248,6 +249,30 @@ message GoogleDrive {
}
}
message Huggingface {
string endpoint = 1 [(validate.rules).string.uri_ref = true];
oneof credential {
string token = 2;
credentials.Unauthenticated unauthenticated = 3;
}
repeated string models = 4;
repeated string spaces = 5;
repeated string datasets = 12;
repeated string organizations = 6;
repeated string users = 15;
repeated string ignore_models = 7;
repeated string include_models = 8;
repeated string ignore_spaces = 9;
repeated string include_spaces = 10;
repeated string ignore_datasets = 13;
repeated string include_datasets = 14;
bool skip_all_models = 16;
bool skip_all_spaces = 17;
bool skip_all_datasets = 18;
bool include_discussions = 11;
bool include_prs = 19;
}
message JIRA {
string endpoint = 1 [(validate.rules).string.uri_ref = true];
oneof credential {