mirror of
https://github.com/trufflesecurity/trufflehog.git
synced 2024-11-10 07:04:24 +00:00
New Source: HuggingFace (#3000)
* initial spike on hf * added in user and org enum * adding huggingface source * updated with lint suggestions * updated readme * addressing resources that require org approval to access * removing unneeded code * updating with new error msg for 403 * deleted unused code + added resource check in main
This commit is contained in:
parent
e9206c66bb
commit
01a1499600
14 changed files with 4568 additions and 964 deletions
21
README.md
21
README.md
|
@ -301,6 +301,27 @@ trufflehog elasticsearch \
|
|||
--api-key 'MlVtVjBZ...ZSYlduYnF1djh3NG5FQQ=='
|
||||
```
|
||||
|
||||
## 15. Scan HuggingFace
|
||||
|
||||
### Scan a HuggingFace Model, Dataset or Space
|
||||
|
||||
```bash
|
||||
trufflehog huggingface --model <username/modelname> --space <username/spacename> --dataset <username/datasetname>
|
||||
```
|
||||
|
||||
### Scan all Models, Datasets and Space belonging to a HuggingFace Org/User
|
||||
|
||||
```bash
|
||||
trufflehog huggingface --org <orgname> --user <username>
|
||||
```
|
||||
|
||||
Optionally, skip scanning a type of resource with `--skip-models`, `--skip-datasets`, `--skip-spaces` or a particular resource with `--ignore-models/datasets/spaces <resource-name>`.
|
||||
|
||||
### Scan Discussion and PR Comments
|
||||
```bash
|
||||
trufflehog huggingface --model <username/modelname> --include-discussions --include-prs
|
||||
```
|
||||
|
||||
# :question: FAQ
|
||||
|
||||
- All I see is `🐷🔑🐷 TruffleHog. Unearth your secrets. 🐷🔑🐷` and the program exits, what gives?
|
||||
|
|
54
main.go
54
main.go
|
@ -198,6 +198,27 @@ var (
|
|||
jenkinsPassword = jenkinsScan.Flag("password", "Jenkins password").Envar("JENKINS_PASSWORD").String()
|
||||
jenkinsInsecureSkipVerifyTLS = jenkinsScan.Flag("insecure-skip-verify-tls", "Skip TLS verification").Envar("JENKINS_INSECURE_SKIP_VERIFY_TLS").Bool()
|
||||
|
||||
huggingfaceScan = cli.Command("huggingface", "Find credentials in HuggingFace datasets, models and spaces.")
|
||||
huggingfaceEndpoint = huggingfaceScan.Flag("endpoint", "HuggingFace endpoint.").Default("https://huggingface.co").String()
|
||||
huggingfaceModels = huggingfaceScan.Flag("model", "HuggingFace model to scan. You can repeat this flag. Example: 'username/model'").Strings()
|
||||
huggingfaceSpaces = huggingfaceScan.Flag("space", "HuggingFace space to scan. You can repeat this flag. Example: 'username/space'").Strings()
|
||||
huggingfaceDatasets = huggingfaceScan.Flag("dataset", "HuggingFace dataset to scan. You can repeat this flag. Example: 'username/dataset'").Strings()
|
||||
huggingfaceOrgs = huggingfaceScan.Flag("org", `HuggingFace organization to scan. You can repeat this flag. Example: "trufflesecurity"`).Strings()
|
||||
huggingfaceUsers = huggingfaceScan.Flag("user", `HuggingFace user to scan. You can repeat this flag. Example: "trufflesecurity"`).Strings()
|
||||
huggingfaceToken = huggingfaceScan.Flag("token", "HuggingFace token. Can be provided with environment variable HUGGINGFACE_TOKEN.").Envar("HUGGINGFACE_TOKEN").String()
|
||||
|
||||
huggingfaceIncludeModels = huggingfaceScan.Flag("include-models", "Models to include in scan. You can repeat this flag. Must use HuggingFace model full name. Example: 'username/model' (Only used with --user or --org)").Strings()
|
||||
huggingfaceIncludeSpaces = huggingfaceScan.Flag("include-spaces", "Spaces to include in scan. You can repeat this flag. Must use HuggingFace space full name. Example: 'username/space' (Only used with --user or --org)").Strings()
|
||||
huggingfaceIncludeDatasets = huggingfaceScan.Flag("include-datasets", "Datasets to include in scan. You can repeat this flag. Must use HuggingFace dataset full name. Example: 'username/dataset' (Only used with --user or --org)").Strings()
|
||||
huggingfaceIgnoreModels = huggingfaceScan.Flag("ignore-models", "Models to ignore in scan. You can repeat this flag. Must use HuggingFace model full name. Example: 'username/model' (Only used with --user or --org)").Strings()
|
||||
huggingfaceIgnoreSpaces = huggingfaceScan.Flag("ignore-spaces", "Spaces to ignore in scan. You can repeat this flag. Must use HuggingFace space full name. Example: 'username/space' (Only used with --user or --org)").Strings()
|
||||
huggingfaceIgnoreDatasets = huggingfaceScan.Flag("ignore-datasets", "Datasets to ignore in scan. You can repeat this flag. Must use HuggingFace dataset full name. Example: 'username/dataset' (Only used with --user or --org)").Strings()
|
||||
huggingfaceSkipAllModels = huggingfaceScan.Flag("skip-all-models", "Skip all model scans. (Only used with --user or --org)").Bool()
|
||||
huggingfaceSkipAllSpaces = huggingfaceScan.Flag("skip-all-spaces", "Skip all space scans. (Only used with --user or --org)").Bool()
|
||||
huggingfaceSkipAllDatasets = huggingfaceScan.Flag("skip-all-datasets", "Skip all dataset scans. (Only used with --user or --org)").Bool()
|
||||
huggingfaceIncludeDiscussions = huggingfaceScan.Flag("include-discussions", "Include discussions in scan.").Bool()
|
||||
huggingfaceIncludePrs = huggingfaceScan.Flag("include-prs", "Include pull requests in scan.").Bool()
|
||||
|
||||
usingTUI = false
|
||||
)
|
||||
|
||||
|
@ -738,6 +759,39 @@ func runSingleScan(ctx context.Context, cmd string, cfg engine.Config) (metrics,
|
|||
if err := eng.ScanJenkins(ctx, cfg); err != nil {
|
||||
return scanMetrics, fmt.Errorf("failed to scan Jenkins: %v", err)
|
||||
}
|
||||
case huggingfaceScan.FullCommand():
|
||||
if *huggingfaceEndpoint != "" {
|
||||
*huggingfaceEndpoint = strings.TrimRight(*huggingfaceEndpoint, "/")
|
||||
}
|
||||
|
||||
if len(*huggingfaceModels) == 0 && len(*huggingfaceSpaces) == 0 && len(*huggingfaceDatasets) == 0 && len(*huggingfaceOrgs) == 0 && len(*huggingfaceUsers) == 0 {
|
||||
return scanMetrics, fmt.Errorf("invalid config: you must specify at least one organization, user, model, space or dataset")
|
||||
}
|
||||
|
||||
cfg := engine.HuggingfaceConfig{
|
||||
Endpoint: *huggingfaceEndpoint,
|
||||
Models: *huggingfaceModels,
|
||||
Spaces: *huggingfaceSpaces,
|
||||
Datasets: *huggingfaceDatasets,
|
||||
Organizations: *huggingfaceOrgs,
|
||||
Users: *huggingfaceUsers,
|
||||
Token: *huggingfaceToken,
|
||||
IncludeModels: *huggingfaceIncludeModels,
|
||||
IncludeSpaces: *huggingfaceIncludeSpaces,
|
||||
IncludeDatasets: *huggingfaceIncludeDatasets,
|
||||
IgnoreModels: *huggingfaceIgnoreModels,
|
||||
IgnoreSpaces: *huggingfaceIgnoreSpaces,
|
||||
IgnoreDatasets: *huggingfaceIgnoreDatasets,
|
||||
SkipAllModels: *huggingfaceSkipAllModels,
|
||||
SkipAllSpaces: *huggingfaceSkipAllSpaces,
|
||||
SkipAllDatasets: *huggingfaceSkipAllDatasets,
|
||||
IncludeDiscussions: *huggingfaceIncludeDiscussions,
|
||||
IncludePrs: *huggingfaceIncludePrs,
|
||||
Concurrency: *concurrency,
|
||||
}
|
||||
if err := eng.ScanHuggingface(ctx, cfg); err != nil {
|
||||
return scanMetrics, fmt.Errorf("failed to scan HuggingFace: %v", err)
|
||||
}
|
||||
default:
|
||||
return scanMetrics, fmt.Errorf("invalid command: %s", cmd)
|
||||
}
|
||||
|
|
80
pkg/engine/huggingface.go
Normal file
80
pkg/engine/huggingface.go
Normal file
|
@ -0,0 +1,80 @@
|
|||
package engine
|
||||
|
||||
import (
|
||||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/types/known/anypb"
|
||||
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/context"
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/pb/sourcespb"
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/sources/huggingface"
|
||||
)
|
||||
|
||||
// HuggingFaceConfig represents the configuration for HuggingFace.
|
||||
type HuggingfaceConfig struct {
|
||||
Endpoint string
|
||||
Models []string
|
||||
Spaces []string
|
||||
Datasets []string
|
||||
Organizations []string
|
||||
Users []string
|
||||
IncludeModels []string
|
||||
IgnoreModels []string
|
||||
IncludeSpaces []string
|
||||
IgnoreSpaces []string
|
||||
IncludeDatasets []string
|
||||
IgnoreDatasets []string
|
||||
SkipAllModels bool
|
||||
SkipAllSpaces bool
|
||||
SkipAllDatasets bool
|
||||
IncludeDiscussions bool
|
||||
IncludePrs bool
|
||||
Token string
|
||||
Concurrency int
|
||||
}
|
||||
|
||||
// ScanGitHub scans HuggingFace with the provided options.
|
||||
func (e *Engine) ScanHuggingface(ctx context.Context, c HuggingfaceConfig) error {
|
||||
connection := sourcespb.Huggingface{
|
||||
Endpoint: c.Endpoint,
|
||||
Models: c.Models,
|
||||
Spaces: c.Spaces,
|
||||
Datasets: c.Datasets,
|
||||
Organizations: c.Organizations,
|
||||
Users: c.Users,
|
||||
IncludeModels: c.IncludeModels,
|
||||
IgnoreModels: c.IgnoreModels,
|
||||
IncludeSpaces: c.IncludeSpaces,
|
||||
IgnoreSpaces: c.IgnoreSpaces,
|
||||
IncludeDatasets: c.IncludeDatasets,
|
||||
IgnoreDatasets: c.IgnoreDatasets,
|
||||
SkipAllModels: c.SkipAllModels,
|
||||
SkipAllSpaces: c.SkipAllSpaces,
|
||||
SkipAllDatasets: c.SkipAllDatasets,
|
||||
IncludeDiscussions: c.IncludeDiscussions,
|
||||
IncludePrs: c.IncludePrs,
|
||||
}
|
||||
if len(c.Token) > 0 {
|
||||
connection.Credential = &sourcespb.Huggingface_Token{
|
||||
Token: c.Token,
|
||||
}
|
||||
} else {
|
||||
connection.Credential = &sourcespb.Huggingface_Unauthenticated{}
|
||||
}
|
||||
|
||||
var conn anypb.Any
|
||||
err := anypb.MarshalFrom(&conn, &connection, proto.MarshalOptions{})
|
||||
if err != nil {
|
||||
ctx.Logger().Error(err, "failed to marshal huggingface connection")
|
||||
return err
|
||||
}
|
||||
|
||||
sourceName := "trufflehog - huggingface"
|
||||
sourceID, jobID, _ := e.sourceManager.GetIDs(ctx, sourceName, sourcespb.SourceType_SOURCE_TYPE_HUGGINGFACE)
|
||||
|
||||
huggingfaceSource := &huggingface.Source{}
|
||||
if err := huggingfaceSource.Init(ctx, sourceName, jobID, sourceID, true, &conn, c.Concurrency); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = e.sourceManager.Run(ctx, sourceName, huggingfaceSource)
|
||||
return err
|
||||
}
|
File diff suppressed because it is too large
Load diff
|
@ -1493,6 +1493,125 @@ var _ interface {
|
|||
ErrorName() string
|
||||
} = GCSValidationError{}
|
||||
|
||||
// Validate checks the field values on Huggingface with the rules defined in
|
||||
// the proto definition for this message. If any rules are violated, the first
|
||||
// error encountered is returned, or nil if there are no violations.
|
||||
func (m *Huggingface) Validate() error {
|
||||
return m.validate(false)
|
||||
}
|
||||
|
||||
// ValidateAll checks the field values on Huggingface with the rules defined in
|
||||
// the proto definition for this message. If any rules are violated, the
|
||||
// result is a list of violation errors wrapped in HuggingfaceMultiError, or
|
||||
// nil if none found.
|
||||
func (m *Huggingface) ValidateAll() error {
|
||||
return m.validate(true)
|
||||
}
|
||||
|
||||
func (m *Huggingface) validate(all bool) error {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var errors []error
|
||||
|
||||
// no validation rules for Link
|
||||
|
||||
// no validation rules for Username
|
||||
|
||||
// no validation rules for Repository
|
||||
|
||||
// no validation rules for Commit
|
||||
|
||||
// no validation rules for Email
|
||||
|
||||
// no validation rules for File
|
||||
|
||||
// no validation rules for Timestamp
|
||||
|
||||
// no validation rules for Line
|
||||
|
||||
// no validation rules for Visibility
|
||||
|
||||
// no validation rules for ResourceType
|
||||
|
||||
if len(errors) > 0 {
|
||||
return HuggingfaceMultiError(errors)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// HuggingfaceMultiError is an error wrapping multiple validation errors
|
||||
// returned by Huggingface.ValidateAll() if the designated constraints aren't met.
|
||||
type HuggingfaceMultiError []error
|
||||
|
||||
// Error returns a concatenation of all the error messages it wraps.
|
||||
func (m HuggingfaceMultiError) Error() string {
|
||||
var msgs []string
|
||||
for _, err := range m {
|
||||
msgs = append(msgs, err.Error())
|
||||
}
|
||||
return strings.Join(msgs, "; ")
|
||||
}
|
||||
|
||||
// AllErrors returns a list of validation violation errors.
|
||||
func (m HuggingfaceMultiError) AllErrors() []error { return m }
|
||||
|
||||
// HuggingfaceValidationError is the validation error returned by
|
||||
// Huggingface.Validate if the designated constraints aren't met.
|
||||
type HuggingfaceValidationError struct {
|
||||
field string
|
||||
reason string
|
||||
cause error
|
||||
key bool
|
||||
}
|
||||
|
||||
// Field function returns field value.
|
||||
func (e HuggingfaceValidationError) Field() string { return e.field }
|
||||
|
||||
// Reason function returns reason value.
|
||||
func (e HuggingfaceValidationError) Reason() string { return e.reason }
|
||||
|
||||
// Cause function returns cause value.
|
||||
func (e HuggingfaceValidationError) Cause() error { return e.cause }
|
||||
|
||||
// Key function returns key value.
|
||||
func (e HuggingfaceValidationError) Key() bool { return e.key }
|
||||
|
||||
// ErrorName returns error name.
|
||||
func (e HuggingfaceValidationError) ErrorName() string { return "HuggingfaceValidationError" }
|
||||
|
||||
// Error satisfies the builtin error interface
|
||||
func (e HuggingfaceValidationError) Error() string {
|
||||
cause := ""
|
||||
if e.cause != nil {
|
||||
cause = fmt.Sprintf(" | caused by: %v", e.cause)
|
||||
}
|
||||
|
||||
key := ""
|
||||
if e.key {
|
||||
key = "key for "
|
||||
}
|
||||
|
||||
return fmt.Sprintf(
|
||||
"invalid %sHuggingface.%s: %s%s",
|
||||
key,
|
||||
e.field,
|
||||
e.reason,
|
||||
cause)
|
||||
}
|
||||
|
||||
var _ error = HuggingfaceValidationError{}
|
||||
|
||||
var _ interface {
|
||||
Field() string
|
||||
Reason() string
|
||||
Key() bool
|
||||
Cause() error
|
||||
ErrorName() string
|
||||
} = HuggingfaceValidationError{}
|
||||
|
||||
// Validate checks the field values on Jira with the rules defined in the proto
|
||||
// definition for this message. If any rules are violated, the first error
|
||||
// encountered is returned, or nil if there are no violations.
|
||||
|
@ -5074,6 +5193,47 @@ func (m *MetaData) validate(all bool) error {
|
|||
}
|
||||
}
|
||||
|
||||
case *MetaData_Huggingface:
|
||||
if v == nil {
|
||||
err := MetaDataValidationError{
|
||||
field: "Data",
|
||||
reason: "oneof value cannot be a typed-nil",
|
||||
}
|
||||
if !all {
|
||||
return err
|
||||
}
|
||||
errors = append(errors, err)
|
||||
}
|
||||
|
||||
if all {
|
||||
switch v := interface{}(m.GetHuggingface()).(type) {
|
||||
case interface{ ValidateAll() error }:
|
||||
if err := v.ValidateAll(); err != nil {
|
||||
errors = append(errors, MetaDataValidationError{
|
||||
field: "Huggingface",
|
||||
reason: "embedded message failed validation",
|
||||
cause: err,
|
||||
})
|
||||
}
|
||||
case interface{ Validate() error }:
|
||||
if err := v.Validate(); err != nil {
|
||||
errors = append(errors, MetaDataValidationError{
|
||||
field: "Huggingface",
|
||||
reason: "embedded message failed validation",
|
||||
cause: err,
|
||||
})
|
||||
}
|
||||
}
|
||||
} else if v, ok := interface{}(m.GetHuggingface()).(interface{ Validate() error }); ok {
|
||||
if err := v.Validate(); err != nil {
|
||||
return MetaDataValidationError{
|
||||
field: "Huggingface",
|
||||
reason: "embedded message failed validation",
|
||||
cause: err,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
default:
|
||||
_ = v // ensures v is used
|
||||
}
|
||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -2875,6 +2875,185 @@ var _ interface {
|
|||
ErrorName() string
|
||||
} = GoogleDriveValidationError{}
|
||||
|
||||
// Validate checks the field values on Huggingface with the rules defined in
|
||||
// the proto definition for this message. If any rules are violated, the first
|
||||
// error encountered is returned, or nil if there are no violations.
|
||||
func (m *Huggingface) Validate() error {
|
||||
return m.validate(false)
|
||||
}
|
||||
|
||||
// ValidateAll checks the field values on Huggingface with the rules defined in
|
||||
// the proto definition for this message. If any rules are violated, the
|
||||
// result is a list of violation errors wrapped in HuggingfaceMultiError, or
|
||||
// nil if none found.
|
||||
func (m *Huggingface) ValidateAll() error {
|
||||
return m.validate(true)
|
||||
}
|
||||
|
||||
func (m *Huggingface) validate(all bool) error {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var errors []error
|
||||
|
||||
if _, err := url.Parse(m.GetEndpoint()); err != nil {
|
||||
err = HuggingfaceValidationError{
|
||||
field: "Endpoint",
|
||||
reason: "value must be a valid URI",
|
||||
cause: err,
|
||||
}
|
||||
if !all {
|
||||
return err
|
||||
}
|
||||
errors = append(errors, err)
|
||||
}
|
||||
|
||||
// no validation rules for SkipAllModels
|
||||
|
||||
// no validation rules for SkipAllSpaces
|
||||
|
||||
// no validation rules for SkipAllDatasets
|
||||
|
||||
// no validation rules for IncludeDiscussions
|
||||
|
||||
// no validation rules for IncludePrs
|
||||
|
||||
switch v := m.Credential.(type) {
|
||||
case *Huggingface_Token:
|
||||
if v == nil {
|
||||
err := HuggingfaceValidationError{
|
||||
field: "Credential",
|
||||
reason: "oneof value cannot be a typed-nil",
|
||||
}
|
||||
if !all {
|
||||
return err
|
||||
}
|
||||
errors = append(errors, err)
|
||||
}
|
||||
// no validation rules for Token
|
||||
case *Huggingface_Unauthenticated:
|
||||
if v == nil {
|
||||
err := HuggingfaceValidationError{
|
||||
field: "Credential",
|
||||
reason: "oneof value cannot be a typed-nil",
|
||||
}
|
||||
if !all {
|
||||
return err
|
||||
}
|
||||
errors = append(errors, err)
|
||||
}
|
||||
|
||||
if all {
|
||||
switch v := interface{}(m.GetUnauthenticated()).(type) {
|
||||
case interface{ ValidateAll() error }:
|
||||
if err := v.ValidateAll(); err != nil {
|
||||
errors = append(errors, HuggingfaceValidationError{
|
||||
field: "Unauthenticated",
|
||||
reason: "embedded message failed validation",
|
||||
cause: err,
|
||||
})
|
||||
}
|
||||
case interface{ Validate() error }:
|
||||
if err := v.Validate(); err != nil {
|
||||
errors = append(errors, HuggingfaceValidationError{
|
||||
field: "Unauthenticated",
|
||||
reason: "embedded message failed validation",
|
||||
cause: err,
|
||||
})
|
||||
}
|
||||
}
|
||||
} else if v, ok := interface{}(m.GetUnauthenticated()).(interface{ Validate() error }); ok {
|
||||
if err := v.Validate(); err != nil {
|
||||
return HuggingfaceValidationError{
|
||||
field: "Unauthenticated",
|
||||
reason: "embedded message failed validation",
|
||||
cause: err,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
default:
|
||||
_ = v // ensures v is used
|
||||
}
|
||||
|
||||
if len(errors) > 0 {
|
||||
return HuggingfaceMultiError(errors)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// HuggingfaceMultiError is an error wrapping multiple validation errors
|
||||
// returned by Huggingface.ValidateAll() if the designated constraints aren't met.
|
||||
type HuggingfaceMultiError []error
|
||||
|
||||
// Error returns a concatenation of all the error messages it wraps.
|
||||
func (m HuggingfaceMultiError) Error() string {
|
||||
var msgs []string
|
||||
for _, err := range m {
|
||||
msgs = append(msgs, err.Error())
|
||||
}
|
||||
return strings.Join(msgs, "; ")
|
||||
}
|
||||
|
||||
// AllErrors returns a list of validation violation errors.
|
||||
func (m HuggingfaceMultiError) AllErrors() []error { return m }
|
||||
|
||||
// HuggingfaceValidationError is the validation error returned by
|
||||
// Huggingface.Validate if the designated constraints aren't met.
|
||||
type HuggingfaceValidationError struct {
|
||||
field string
|
||||
reason string
|
||||
cause error
|
||||
key bool
|
||||
}
|
||||
|
||||
// Field function returns field value.
|
||||
func (e HuggingfaceValidationError) Field() string { return e.field }
|
||||
|
||||
// Reason function returns reason value.
|
||||
func (e HuggingfaceValidationError) Reason() string { return e.reason }
|
||||
|
||||
// Cause function returns cause value.
|
||||
func (e HuggingfaceValidationError) Cause() error { return e.cause }
|
||||
|
||||
// Key function returns key value.
|
||||
func (e HuggingfaceValidationError) Key() bool { return e.key }
|
||||
|
||||
// ErrorName returns error name.
|
||||
func (e HuggingfaceValidationError) ErrorName() string { return "HuggingfaceValidationError" }
|
||||
|
||||
// Error satisfies the builtin error interface
|
||||
func (e HuggingfaceValidationError) Error() string {
|
||||
cause := ""
|
||||
if e.cause != nil {
|
||||
cause = fmt.Sprintf(" | caused by: %v", e.cause)
|
||||
}
|
||||
|
||||
key := ""
|
||||
if e.key {
|
||||
key = "key for "
|
||||
}
|
||||
|
||||
return fmt.Sprintf(
|
||||
"invalid %sHuggingface.%s: %s%s",
|
||||
key,
|
||||
e.field,
|
||||
e.reason,
|
||||
cause)
|
||||
}
|
||||
|
||||
var _ error = HuggingfaceValidationError{}
|
||||
|
||||
var _ interface {
|
||||
Field() string
|
||||
Reason() string
|
||||
Key() bool
|
||||
Cause() error
|
||||
ErrorName() string
|
||||
} = HuggingfaceValidationError{}
|
||||
|
||||
// Validate checks the field values on JIRA with the rules defined in the proto
|
||||
// definition for this message. If any rules are violated, the first error
|
||||
// encountered is returned, or nil if there are no violations.
|
||||
|
|
223
pkg/sources/huggingface/client.go
Normal file
223
pkg/sources/huggingface/client.go
Normal file
|
@ -0,0 +1,223 @@
|
|||
package huggingface
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/context"
|
||||
)
|
||||
|
||||
// Maps for API and HTML paths
|
||||
var apiPaths = map[string]string{
|
||||
DATASET: DatasetsRoute,
|
||||
MODEL: ModelsAPIRoute,
|
||||
SPACE: SpacesRoute,
|
||||
}
|
||||
|
||||
var htmlPaths = map[string]string{
|
||||
DATASET: DatasetsRoute,
|
||||
MODEL: "",
|
||||
SPACE: SpacesRoute,
|
||||
}
|
||||
|
||||
type Author struct {
|
||||
Username string `json:"name"`
|
||||
}
|
||||
|
||||
type Latest struct {
|
||||
Raw string `json:"raw"`
|
||||
}
|
||||
|
||||
type Data struct {
|
||||
Latest Latest `json:"latest"`
|
||||
}
|
||||
|
||||
type Event struct {
|
||||
Type string `json:"type"`
|
||||
Author Author `json:"author"`
|
||||
CreatedAt string `json:"createdAt"`
|
||||
Data Data `json:"data"`
|
||||
ID string `json:"id"`
|
||||
}
|
||||
|
||||
func (e Event) GetAuthor() string {
|
||||
return e.Author.Username
|
||||
}
|
||||
|
||||
func (e Event) GetCreatedAt() string {
|
||||
return e.CreatedAt
|
||||
}
|
||||
|
||||
func (e Event) GetID() string {
|
||||
return e.ID
|
||||
}
|
||||
|
||||
type RepoData struct {
|
||||
FullName string `json:"name"`
|
||||
ResourceType string `json:"type"`
|
||||
}
|
||||
|
||||
type Discussion struct {
|
||||
ID int `json:"num"`
|
||||
IsPR bool `json:"isPullRequest"`
|
||||
CreatedAt string `json:"createdAt"`
|
||||
Title string `json:"title"`
|
||||
Events []Event `json:"events"`
|
||||
Repo RepoData `json:"repo"`
|
||||
}
|
||||
|
||||
func (d Discussion) GetID() string {
|
||||
return fmt.Sprint(d.ID)
|
||||
}
|
||||
|
||||
func (d Discussion) GetTitle() string {
|
||||
return d.Title
|
||||
}
|
||||
|
||||
func (d Discussion) GetCreatedAt() string {
|
||||
return d.CreatedAt
|
||||
}
|
||||
|
||||
func (d Discussion) GetRepo() string {
|
||||
return d.Repo.FullName
|
||||
}
|
||||
|
||||
// GetDiscussionPath returns the path (ex: "/models/user/repo/discussions/1") for the discussion
|
||||
func (d Discussion) GetDiscussionPath() string {
|
||||
basePath := fmt.Sprintf("%s/%s/%s", d.GetRepo(), DiscussionsRoute, d.GetID())
|
||||
if d.Repo.ResourceType == "model" {
|
||||
return basePath
|
||||
}
|
||||
return fmt.Sprintf("%s/%s", getResourceHTMLPath(d.Repo.ResourceType), basePath)
|
||||
}
|
||||
|
||||
// GetGitPath returns the path (ex: "/models/user/repo.git") for the repo's git directory
|
||||
func (d Discussion) GetGitPath() string {
|
||||
basePath := fmt.Sprintf("%s.git", d.GetRepo())
|
||||
if d.Repo.ResourceType == "model" {
|
||||
return basePath
|
||||
}
|
||||
return fmt.Sprintf("%s/%s", getResourceHTMLPath(d.Repo.ResourceType), basePath)
|
||||
}
|
||||
|
||||
type DiscussionList struct {
|
||||
Discussions []Discussion `json:"discussions"`
|
||||
}
|
||||
|
||||
type Repo struct {
|
||||
IsPrivate bool `json:"private"`
|
||||
Owner string `json:"author"`
|
||||
RepoID string `json:"id"`
|
||||
}
|
||||
|
||||
type HFClient struct {
|
||||
BaseURL string
|
||||
APIKey string
|
||||
HTTPClient *http.Client
|
||||
}
|
||||
|
||||
// NewClient creates a new API client
|
||||
func NewHFClient(baseURL, apiKey string, timeout time.Duration) *HFClient {
|
||||
return &HFClient{
|
||||
BaseURL: baseURL,
|
||||
APIKey: apiKey,
|
||||
HTTPClient: &http.Client{
|
||||
Timeout: timeout,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// get makes a GET request to the Hugging Face API
|
||||
// Note: not addressing rate limit, since it seems very permissive. (ex: "If \
|
||||
// your account suddenly sends 10k requests then you’re likely to receive 503")
|
||||
func (c *HFClient) get(ctx context.Context, url string, target interface{}) error {
|
||||
ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create HuggingFace API request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+c.APIKey)
|
||||
|
||||
resp, err := c.HTTPClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to make request to HuggingFace API: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode == http.StatusUnauthorized {
|
||||
return errors.New("invalid API key.")
|
||||
}
|
||||
|
||||
if resp.StatusCode == http.StatusForbidden {
|
||||
return errors.New("access to this repo is restricted and you are not in the authorized list. Visit the repository to ask for access.")
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
|
||||
return json.NewDecoder(resp.Body).Decode(target)
|
||||
}
|
||||
|
||||
// GetRepo retrieves repo from the Hugging Face API
|
||||
func (c *HFClient) GetRepo(ctx context.Context, repoName string, resourceType string) (Repo, error) {
|
||||
var repo Repo
|
||||
url, err := buildAPIURL(c.BaseURL, resourceType, repoName)
|
||||
if err != nil {
|
||||
return repo, err
|
||||
}
|
||||
err = c.get(ctx, url, &repo)
|
||||
return repo, err
|
||||
}
|
||||
|
||||
// ListDiscussions retrieves discussions from the Hugging Face API
|
||||
func (c *HFClient) ListDiscussions(ctx context.Context, repoInfo repoInfo) (DiscussionList, error) {
|
||||
var discussions DiscussionList
|
||||
baseURL, err := buildAPIURL(c.BaseURL, string(repoInfo.resourceType), repoInfo.fullName)
|
||||
if err != nil {
|
||||
return discussions, err
|
||||
}
|
||||
url := fmt.Sprintf("%s/%s", baseURL, DiscussionsRoute)
|
||||
err = c.get(ctx, url, &discussions)
|
||||
return discussions, err
|
||||
}
|
||||
|
||||
func (c *HFClient) GetDiscussionByID(ctx context.Context, repoInfo repoInfo, discussionID string) (Discussion, error) {
|
||||
var discussion Discussion
|
||||
baseURL, err := buildAPIURL(c.BaseURL, string(repoInfo.resourceType), repoInfo.fullName)
|
||||
if err != nil {
|
||||
return discussion, err
|
||||
}
|
||||
url := fmt.Sprintf("%s/%s/%s", baseURL, DiscussionsRoute, discussionID)
|
||||
err = c.get(ctx, url, &discussion)
|
||||
return discussion, err
|
||||
}
|
||||
|
||||
// ListReposByAuthor retrieves repos from the Hugging Face API by author (user or org)
|
||||
// Note: not addressing pagination b/c allow by default 1000 results, which should be enough for 99.99% of cases
|
||||
func (c *HFClient) ListReposByAuthor(ctx context.Context, resourceType string, author string) ([]Repo, error) {
|
||||
var repos []Repo
|
||||
url := fmt.Sprintf("%s/%s/%s?limit=1000&author=%s", c.BaseURL, APIRoute, getResourceAPIPath(resourceType), author)
|
||||
err := c.get(ctx, url, &repos)
|
||||
return repos, err
|
||||
}
|
||||
|
||||
// getResourceAPIPath returns the API path for the given resource type
|
||||
func getResourceAPIPath(resourceType string) string {
|
||||
return apiPaths[resourceType]
|
||||
}
|
||||
|
||||
// getResourceHTMLPath returns the HTML path for the given resource type
|
||||
func getResourceHTMLPath(resourceType string) string {
|
||||
return htmlPaths[resourceType]
|
||||
}
|
||||
|
||||
func buildAPIURL(endpoint string, resourceType string, repoName string) (string, error) {
|
||||
if endpoint == "" || resourceType == "" || repoName == "" {
|
||||
return "", errors.New("endpoint, resourceType, and repoName must not be empty")
|
||||
}
|
||||
return fmt.Sprintf("%s/%s/%s/%s", endpoint, APIRoute, getResourceAPIPath(resourceType), repoName), nil
|
||||
}
|
680
pkg/sources/huggingface/huggingface.go
Normal file
680
pkg/sources/huggingface/huggingface.go
Normal file
|
@ -0,0 +1,680 @@
|
|||
package huggingface
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/go-logr/logr"
|
||||
"github.com/gobwas/glob"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/types/known/anypb"
|
||||
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/cache"
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/cache/memory"
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/common"
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/context"
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/giturl"
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/pb/source_metadatapb"
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/pb/sourcespb"
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/sanitizer"
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/sources"
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/sources/git"
|
||||
)
|
||||
|
||||
const (
|
||||
SourceType = sourcespb.SourceType_SOURCE_TYPE_HUGGINGFACE
|
||||
DatasetsRoute = "datasets"
|
||||
SpacesRoute = "spaces"
|
||||
ModelsAPIRoute = "models"
|
||||
DiscussionsRoute = "discussions"
|
||||
APIRoute = "api"
|
||||
DATASET = "dataset"
|
||||
MODEL = "model"
|
||||
SPACE = "space"
|
||||
defaultPagination = 100
|
||||
)
|
||||
|
||||
type resourceType string
|
||||
|
||||
type Source struct {
|
||||
name string
|
||||
huggingfaceToken string
|
||||
|
||||
sourceID sources.SourceID
|
||||
jobID sources.JobID
|
||||
verify bool
|
||||
useCustomContentWriter bool
|
||||
orgsCache cache.Cache[string]
|
||||
usersCache cache.Cache[string]
|
||||
|
||||
models []string
|
||||
spaces []string
|
||||
datasets []string
|
||||
|
||||
filteredModelsCache *filteredRepoCache
|
||||
filteredSpacesCache *filteredRepoCache
|
||||
filteredDatasetsCache *filteredRepoCache
|
||||
|
||||
repoInfoCache repoInfoCache
|
||||
|
||||
git *git.Git
|
||||
|
||||
scanOptions *git.ScanOptions
|
||||
|
||||
apiClient *HFClient
|
||||
log logr.Logger
|
||||
conn *sourcespb.Huggingface
|
||||
jobPool *errgroup.Group
|
||||
resumeInfoMutex sync.Mutex
|
||||
resumeInfoSlice []string
|
||||
|
||||
skipAllModels bool
|
||||
skipAllSpaces bool
|
||||
skipAllDatasets bool
|
||||
includeDiscussions bool
|
||||
includePrs bool
|
||||
|
||||
sources.Progress
|
||||
sources.CommonSourceUnitUnmarshaller
|
||||
}
|
||||
|
||||
// Ensure the Source satisfies the interfaces at compile time
|
||||
var _ sources.Source = (*Source)(nil)
|
||||
var _ sources.SourceUnitUnmarshaller = (*Source)(nil)
|
||||
|
||||
// WithCustomContentWriter sets the useCustomContentWriter flag on the source.
|
||||
func (s *Source) WithCustomContentWriter() { s.useCustomContentWriter = true }
|
||||
|
||||
// Type returns the type of source.
|
||||
// It is used for matching source types in configuration and job input.
|
||||
func (s *Source) Type() sourcespb.SourceType {
|
||||
return SourceType
|
||||
}
|
||||
|
||||
func (s *Source) SourceID() sources.SourceID {
|
||||
return s.sourceID
|
||||
}
|
||||
|
||||
func (s *Source) JobID() sources.JobID {
|
||||
return s.jobID
|
||||
}
|
||||
|
||||
// filteredRepoCache is a wrapper around cache.Cache that filters out repos
|
||||
// based on include and exclude globs.
|
||||
type filteredRepoCache struct {
|
||||
cache.Cache[string]
|
||||
include, exclude []glob.Glob
|
||||
}
|
||||
|
||||
func (s *Source) newFilteredRepoCache(c cache.Cache[string], include, exclude []string) *filteredRepoCache {
|
||||
includeGlobs := make([]glob.Glob, 0, len(include))
|
||||
excludeGlobs := make([]glob.Glob, 0, len(exclude))
|
||||
for _, ig := range include {
|
||||
g, err := glob.Compile(ig)
|
||||
if err != nil {
|
||||
s.log.V(1).Info("invalid include glob", "include_value", ig, "err", err)
|
||||
continue
|
||||
}
|
||||
includeGlobs = append(includeGlobs, g)
|
||||
}
|
||||
for _, eg := range exclude {
|
||||
g, err := glob.Compile(eg)
|
||||
if err != nil {
|
||||
s.log.V(1).Info("invalid exclude glob", "exclude_value", eg, "err", err)
|
||||
continue
|
||||
}
|
||||
excludeGlobs = append(excludeGlobs, g)
|
||||
}
|
||||
return &filteredRepoCache{Cache: c, include: includeGlobs, exclude: excludeGlobs}
|
||||
}
|
||||
|
||||
// Set overrides the cache.Cache Set method to filter out repos based on
|
||||
// include and exclude globs.
|
||||
func (c *filteredRepoCache) Set(key, val string) {
|
||||
if c.ignoreRepo(key) {
|
||||
return
|
||||
}
|
||||
if !c.includeRepo(key) {
|
||||
return
|
||||
}
|
||||
c.Cache.Set(key, val)
|
||||
}
|
||||
|
||||
func (c *filteredRepoCache) ignoreRepo(s string) bool {
|
||||
for _, g := range c.exclude {
|
||||
if g.Match(s) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *filteredRepoCache) includeRepo(s string) bool {
|
||||
if len(c.include) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
for _, g := range c.include {
|
||||
if g.Match(s) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Init returns an initialized HuggingFace source.
|
||||
func (s *Source) Init(aCtx context.Context, name string, jobID sources.JobID, sourceID sources.SourceID, verify bool, connection *anypb.Any, concurrency int) error {
|
||||
err := git.CmdCheck()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.log = aCtx.Logger()
|
||||
|
||||
s.name = name
|
||||
s.sourceID = sourceID
|
||||
s.jobID = jobID
|
||||
s.verify = verify
|
||||
s.jobPool = &errgroup.Group{}
|
||||
s.jobPool.SetLimit(concurrency)
|
||||
|
||||
var conn sourcespb.Huggingface
|
||||
err = anypb.UnmarshalTo(connection, &conn, proto.UnmarshalOptions{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("error unmarshalling connection: %w", err)
|
||||
}
|
||||
s.conn = &conn
|
||||
|
||||
s.orgsCache = memory.New[string]()
|
||||
for _, org := range s.conn.Organizations {
|
||||
s.orgsCache.Set(org, org)
|
||||
}
|
||||
|
||||
s.usersCache = memory.New[string]()
|
||||
for _, user := range s.conn.Users {
|
||||
s.usersCache.Set(user, user)
|
||||
}
|
||||
|
||||
//Verify ignore and include models, spaces, and datasets are valid
|
||||
// this ensures that calling --org <org> --ignore-model <org/model> contains the proper
|
||||
// repo format of org/model. Otherwise, we would scan the entire org.
|
||||
if err := s.validateIgnoreIncludeRepos(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.filteredModelsCache = s.newFilteredRepoCache(memory.New[string](),
|
||||
append(s.conn.GetModels(), s.conn.GetIncludeModels()...),
|
||||
s.conn.GetIgnoreModels(),
|
||||
)
|
||||
|
||||
s.filteredSpacesCache = s.newFilteredRepoCache(memory.New[string](),
|
||||
append(s.conn.GetSpaces(), s.conn.GetIncludeSpaces()...),
|
||||
s.conn.GetIgnoreSpaces(),
|
||||
)
|
||||
|
||||
s.filteredDatasetsCache = s.newFilteredRepoCache(memory.New[string](),
|
||||
append(s.conn.GetDatasets(), s.conn.GetIncludeDatasets()...),
|
||||
s.conn.GetIgnoreDatasets(),
|
||||
)
|
||||
|
||||
s.models = initializeRepos(s.filteredModelsCache, s.conn.Models, fmt.Sprintf("%s/%s.git", s.conn.Endpoint, "%s"))
|
||||
s.spaces = initializeRepos(s.filteredSpacesCache, s.conn.Spaces, fmt.Sprintf("%s/%s/%s.git", s.conn.Endpoint, SpacesRoute, "%s"))
|
||||
s.datasets = initializeRepos(s.filteredDatasetsCache, s.conn.Datasets, fmt.Sprintf("%s/%s/%s.git", s.conn.Endpoint, DatasetsRoute, "%s"))
|
||||
s.repoInfoCache = newRepoInfoCache()
|
||||
|
||||
s.includeDiscussions = s.conn.IncludeDiscussions
|
||||
s.includePrs = s.conn.IncludePrs
|
||||
|
||||
cfg := &git.Config{
|
||||
SourceName: s.name,
|
||||
JobID: s.jobID,
|
||||
SourceID: s.sourceID,
|
||||
SourceType: s.Type(),
|
||||
Verify: s.verify,
|
||||
Concurrency: concurrency,
|
||||
SourceMetadataFunc: func(file, email, commit, timestamp, repository string, line int64) *source_metadatapb.MetaData {
|
||||
return &source_metadatapb.MetaData{
|
||||
Data: &source_metadatapb.MetaData_Huggingface{
|
||||
Huggingface: &source_metadatapb.Huggingface{
|
||||
Commit: sanitizer.UTF8(commit),
|
||||
File: sanitizer.UTF8(file),
|
||||
Email: sanitizer.UTF8(email),
|
||||
Repository: sanitizer.UTF8(repository),
|
||||
Link: giturl.GenerateLink(repository, commit, file, line),
|
||||
Timestamp: sanitizer.UTF8(timestamp),
|
||||
Line: line,
|
||||
Visibility: s.visibilityOf(aCtx, repository),
|
||||
ResourceType: s.getResourceType(aCtx, repository),
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
UseCustomContentWriter: s.useCustomContentWriter,
|
||||
}
|
||||
s.git = git.NewGit(cfg)
|
||||
|
||||
s.huggingfaceToken = s.conn.GetToken()
|
||||
s.apiClient = NewHFClient(s.conn.Endpoint, s.huggingfaceToken, 10*time.Second)
|
||||
|
||||
s.skipAllModels = s.conn.SkipAllModels
|
||||
s.skipAllSpaces = s.conn.SkipAllSpaces
|
||||
s.skipAllDatasets = s.conn.SkipAllDatasets
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Source) validateIgnoreIncludeRepos() error {
|
||||
if err := verifySlashSeparatedStrings(s.conn.IgnoreModels); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := verifySlashSeparatedStrings(s.conn.IncludeModels); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := verifySlashSeparatedStrings(s.conn.IgnoreSpaces); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := verifySlashSeparatedStrings(s.conn.IncludeSpaces); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := verifySlashSeparatedStrings(s.conn.IgnoreDatasets); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := verifySlashSeparatedStrings(s.conn.IncludeDatasets); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func verifySlashSeparatedStrings(s []string) error {
|
||||
for _, str := range s {
|
||||
if !strings.Contains(str, "/") {
|
||||
return fmt.Errorf("invalid owner/repo: %s", str)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func initializeRepos(cache *filteredRepoCache, repos []string, urlPattern string) []string {
|
||||
returnRepos := make([]string, 0)
|
||||
for _, repo := range repos {
|
||||
if !cache.ignoreRepo(repo) {
|
||||
url := fmt.Sprintf(urlPattern, repo)
|
||||
cache.Set(repo, url)
|
||||
returnRepos = append(returnRepos, repo)
|
||||
}
|
||||
}
|
||||
return returnRepos
|
||||
}
|
||||
|
||||
func (s *Source) getResourceType(ctx context.Context, repoURL string) string {
|
||||
repoInfo, ok := s.repoInfoCache.get(repoURL)
|
||||
if !ok {
|
||||
// This should never happen.
|
||||
err := fmt.Errorf("no repoInfo for URL: %s", repoURL)
|
||||
ctx.Logger().Error(err, "failed to get repository resource type")
|
||||
return ""
|
||||
}
|
||||
|
||||
return string(repoInfo.resourceType)
|
||||
}
|
||||
|
||||
func (s *Source) visibilityOf(ctx context.Context, repoURL string) source_metadatapb.Visibility {
|
||||
repoInfo, ok := s.repoInfoCache.get(repoURL)
|
||||
if !ok {
|
||||
// This should never happen.
|
||||
err := fmt.Errorf("no repoInfo for URL: %s", repoURL)
|
||||
ctx.Logger().Error(err, "failed to get repository visibility")
|
||||
return source_metadatapb.Visibility_unknown
|
||||
}
|
||||
|
||||
return repoInfo.visibility
|
||||
}
|
||||
|
||||
// Chunks emits chunks of bytes over a channel.
|
||||
func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk, targets ...sources.ChunkingTarget) error {
|
||||
err := s.enumerate(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return s.scan(ctx, chunksChan)
|
||||
}
|
||||
|
||||
func (s *Source) enumerate(ctx context.Context) error {
|
||||
s.enumerateAuthors(ctx)
|
||||
|
||||
s.models = make([]string, 0, s.filteredModelsCache.Count())
|
||||
for _, repo := range s.filteredModelsCache.Keys() {
|
||||
if err := s.cacheRepoInfo(ctx, repo, MODEL, s.filteredModelsCache); err != nil {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
s.spaces = make([]string, 0, s.filteredSpacesCache.Count())
|
||||
for _, repo := range s.filteredSpacesCache.Keys() {
|
||||
if err := s.cacheRepoInfo(ctx, repo, SPACE, s.filteredSpacesCache); err != nil {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
s.datasets = make([]string, 0, s.filteredDatasetsCache.Count())
|
||||
for _, repo := range s.filteredDatasetsCache.Keys() {
|
||||
if err := s.cacheRepoInfo(ctx, repo, DATASET, s.filteredDatasetsCache); err != nil {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
s.log.Info("Completed enumeration", "num_models", len(s.models), "num_spaces", len(s.spaces), "num_datasets", len(s.datasets))
|
||||
|
||||
// We must sort the repos so we can resume later if necessary.
|
||||
sort.Strings(s.models)
|
||||
sort.Strings(s.datasets)
|
||||
sort.Strings(s.spaces)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Source) cacheRepoInfo(ctx context.Context, repo string, repoType string, repoCache *filteredRepoCache) error {
|
||||
repoURL, _ := repoCache.Get(repo)
|
||||
repoCtx := context.WithValue(ctx, repoType, repoURL)
|
||||
|
||||
if _, ok := s.repoInfoCache.get(repoURL); !ok {
|
||||
repoCtx.Logger().V(2).Info("Caching " + repoType + " info")
|
||||
repo, err := s.apiClient.GetRepo(repoCtx, repo, repoType)
|
||||
if err != nil {
|
||||
repoCtx.Logger().Error(err, "failed to fetch "+repoType)
|
||||
return err
|
||||
}
|
||||
// check if repo empty
|
||||
if repo.RepoID == "" {
|
||||
repoCtx.Logger().Error(fmt.Errorf("no repo found for repo"), repoURL)
|
||||
return nil
|
||||
}
|
||||
s.repoInfoCache.put(repoURL, repoInfo{
|
||||
owner: repo.Owner,
|
||||
name: strings.Split(repo.RepoID, "/")[1],
|
||||
fullName: repo.RepoID,
|
||||
visibility: getVisibility(repo.IsPrivate),
|
||||
resourceType: resourceType(repoType),
|
||||
})
|
||||
}
|
||||
s.updateRepoLists(repoURL, repoType)
|
||||
return nil
|
||||
}
|
||||
|
||||
func getVisibility(isPrivate bool) source_metadatapb.Visibility {
|
||||
if isPrivate {
|
||||
return source_metadatapb.Visibility_private
|
||||
}
|
||||
return source_metadatapb.Visibility_public
|
||||
}
|
||||
|
||||
func (s *Source) updateRepoLists(repoURL string, repoType string) {
|
||||
switch repoType {
|
||||
case MODEL:
|
||||
s.models = append(s.models, repoURL)
|
||||
case SPACE:
|
||||
s.spaces = append(s.spaces, repoURL)
|
||||
case DATASET:
|
||||
s.datasets = append(s.datasets, repoURL)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Source) fetchAndCacheRepos(ctx context.Context, resourceType string, org string) error {
|
||||
var repos []Repo
|
||||
var err error
|
||||
var url string
|
||||
var filteredCache *filteredRepoCache
|
||||
switch resourceType {
|
||||
case MODEL:
|
||||
filteredCache = s.filteredModelsCache
|
||||
url = fmt.Sprintf("%s/%s.git", s.conn.Endpoint, "%s")
|
||||
repos, err = s.apiClient.ListReposByAuthor(ctx, MODEL, org)
|
||||
case SPACE:
|
||||
filteredCache = s.filteredSpacesCache
|
||||
url = fmt.Sprintf("%s/%s/%s.git", s.conn.Endpoint, SpacesRoute, "%s")
|
||||
repos, err = s.apiClient.ListReposByAuthor(ctx, SPACE, org)
|
||||
case DATASET:
|
||||
filteredCache = s.filteredDatasetsCache
|
||||
url = fmt.Sprintf("%s/%s/%s.git", s.conn.Endpoint, DatasetsRoute, "%s")
|
||||
repos, err = s.apiClient.ListReposByAuthor(ctx, DATASET, org)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, repo := range repos {
|
||||
repoURL := fmt.Sprintf(url, repo.RepoID)
|
||||
filteredCache.Set(repo.RepoID, repoURL)
|
||||
if err := s.cacheRepoInfo(ctx, repo.RepoID, resourceType, filteredCache); err != nil {
|
||||
continue
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Source) enumerateAuthors(ctx context.Context) {
|
||||
for _, org := range s.orgsCache.Keys() {
|
||||
orgCtx := context.WithValue(ctx, "organization", org)
|
||||
if !s.skipAllModels {
|
||||
if err := s.fetchAndCacheRepos(orgCtx, MODEL, org); err != nil {
|
||||
orgCtx.Logger().Error(err, "Failed to fetch models for organization")
|
||||
continue
|
||||
}
|
||||
}
|
||||
if !s.skipAllSpaces {
|
||||
if err := s.fetchAndCacheRepos(orgCtx, SPACE, org); err != nil {
|
||||
orgCtx.Logger().Error(err, "Failed to fetch spaces for organization")
|
||||
continue
|
||||
}
|
||||
}
|
||||
if !s.skipAllDatasets {
|
||||
if err := s.fetchAndCacheRepos(orgCtx, DATASET, org); err != nil {
|
||||
orgCtx.Logger().Error(err, "Failed to fetch datasets for organization")
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, user := range s.usersCache.Keys() {
|
||||
userCtx := context.WithValue(ctx, "user", user)
|
||||
if !s.skipAllModels {
|
||||
if err := s.fetchAndCacheRepos(userCtx, MODEL, user); err != nil {
|
||||
userCtx.Logger().Error(err, "Failed to fetch models for user")
|
||||
continue
|
||||
}
|
||||
}
|
||||
if !s.skipAllSpaces {
|
||||
if err := s.fetchAndCacheRepos(userCtx, SPACE, user); err != nil {
|
||||
userCtx.Logger().Error(err, "Failed to fetch spaces for user")
|
||||
continue
|
||||
}
|
||||
}
|
||||
if !s.skipAllDatasets {
|
||||
if err := s.fetchAndCacheRepos(userCtx, DATASET, user); err != nil {
|
||||
userCtx.Logger().Error(err, "Failed to fetch datasets for user")
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Source) scanRepos(ctx context.Context, chunksChan chan *sources.Chunk, resourceType string) error {
|
||||
var scannedCount uint64 = 1
|
||||
|
||||
repos := s.getReposListByType(resourceType)
|
||||
|
||||
s.log.V(2).Info("Found "+resourceType+" to scan", "count", len(repos))
|
||||
|
||||
// If there is resume information available, limit this scan to only the repos that still need scanning.
|
||||
reposToScan, progressIndexOffset := sources.FilterReposToResume(repos, s.GetProgress().EncodedResumeInfo)
|
||||
repos = reposToScan
|
||||
|
||||
scanErrs := sources.NewScanErrors()
|
||||
|
||||
if s.scanOptions == nil {
|
||||
s.scanOptions = &git.ScanOptions{}
|
||||
}
|
||||
|
||||
for i, repoURL := range repos {
|
||||
i, repoURL := i, repoURL
|
||||
s.jobPool.Go(func() error {
|
||||
if common.IsDone(ctx) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO: set progress complete is being called concurrently with i
|
||||
s.setProgressCompleteWithRepo(i, progressIndexOffset, repoURL, resourceType, repos)
|
||||
// Ensure the repo is removed from the resume info after being scanned.
|
||||
defer func(s *Source, repoURL string) {
|
||||
s.resumeInfoMutex.Lock()
|
||||
defer s.resumeInfoMutex.Unlock()
|
||||
s.resumeInfoSlice = sources.RemoveRepoFromResumeInfo(s.resumeInfoSlice, repoURL)
|
||||
}(s, repoURL)
|
||||
|
||||
// Scan the repository
|
||||
repoInfo, ok := s.repoInfoCache.get(repoURL)
|
||||
if !ok {
|
||||
// This should never happen.
|
||||
err := fmt.Errorf("no repoInfo for URL: %s", repoURL)
|
||||
s.log.Error(err, "failed to scan "+resourceType)
|
||||
return nil
|
||||
}
|
||||
repoCtx := context.WithValues(ctx, resourceType, repoURL)
|
||||
duration, err := s.cloneAndScanRepo(repoCtx, repoURL, repoInfo, chunksChan)
|
||||
if err != nil {
|
||||
scanErrs.Add(err)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Scan discussions and PRs, if enabled.
|
||||
if s.includeDiscussions || s.includePrs {
|
||||
if err = s.scanDiscussions(repoCtx, repoInfo, chunksChan); err != nil {
|
||||
scanErrs.Add(fmt.Errorf("error scanning discussions/PRs in repo %s: %w", repoURL, err))
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
repoCtx.Logger().V(2).Info(fmt.Sprintf("scanned %d/%d "+resourceType+"s", scannedCount, len(s.models)), "duration_seconds", duration)
|
||||
atomic.AddUint64(&scannedCount, 1)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
_ = s.jobPool.Wait()
|
||||
if scanErrs.Count() > 0 {
|
||||
s.log.V(0).Info("failed to scan some repositories", "error_count", scanErrs.Count(), "errors", scanErrs.String())
|
||||
}
|
||||
s.SetProgressComplete(len(repos), len(repos), "Completed HuggingFace "+resourceType+" scan", "")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Source) getReposListByType(resourceType string) []string {
|
||||
switch resourceType {
|
||||
case MODEL:
|
||||
return s.models
|
||||
case SPACE:
|
||||
return s.spaces
|
||||
case DATASET:
|
||||
return s.datasets
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Source) scan(ctx context.Context, chunksChan chan *sources.Chunk) error {
|
||||
if err := s.scanRepos(ctx, chunksChan, MODEL); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.scanRepos(ctx, chunksChan, SPACE); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.scanRepos(ctx, chunksChan, DATASET); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Source) cloneAndScanRepo(ctx context.Context, repoURL string, repoInfo repoInfo, chunksChan chan *sources.Chunk) (time.Duration, error) {
|
||||
ctx.Logger().V(2).Info("attempting to clone %s", repoInfo.resourceType)
|
||||
path, repo, err := s.cloneRepo(ctx, repoURL)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer os.RemoveAll(path)
|
||||
|
||||
var logger logr.Logger
|
||||
logger.V(2).Info("scanning %s", repoInfo.resourceType)
|
||||
|
||||
start := time.Now()
|
||||
if err = s.git.ScanRepo(ctx, repo, path, s.scanOptions, sources.ChanReporter{Ch: chunksChan}); err != nil {
|
||||
return 0, fmt.Errorf("error scanning repo %s: %w", repoURL, err)
|
||||
}
|
||||
return time.Since(start), nil
|
||||
}
|
||||
|
||||
// setProgressCompleteWithRepo calls the s.SetProgressComplete after safely setting up the encoded resume info string.
|
||||
func (s *Source) setProgressCompleteWithRepo(index int, offset int, repoURL string, resourceType string, repos []string) {
|
||||
s.resumeInfoMutex.Lock()
|
||||
defer s.resumeInfoMutex.Unlock()
|
||||
|
||||
// Add the repoURL to the resume info slice.
|
||||
s.resumeInfoSlice = append(s.resumeInfoSlice, repoURL)
|
||||
sort.Strings(s.resumeInfoSlice)
|
||||
|
||||
// Make the resume info string from the slice.
|
||||
encodedResumeInfo := sources.EncodeResumeInfo(s.resumeInfoSlice)
|
||||
s.SetProgressComplete(index+offset, len(repos)+offset, fmt.Sprintf("%ss: %s", resourceType, repoURL), encodedResumeInfo)
|
||||
}
|
||||
|
||||
func (s *Source) scanDiscussions(ctx context.Context, repoInfo repoInfo, chunksChan chan *sources.Chunk) error {
|
||||
discussions, err := s.apiClient.ListDiscussions(ctx, repoInfo)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, discussion := range discussions.Discussions {
|
||||
if (discussion.IsPR && s.includePrs) || (!discussion.IsPR && s.includeDiscussions) {
|
||||
d, err := s.apiClient.GetDiscussionByID(ctx, repoInfo, discussion.GetID())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Note: there is no discussion "description" or similar to chunk, only comments
|
||||
if err = s.chunkDiscussionComments(ctx, repoInfo, d, chunksChan); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Source) chunkDiscussionComments(ctx context.Context, repoInfo repoInfo, discussion Discussion, chunksChan chan *sources.Chunk) error {
|
||||
for _, comment := range discussion.Events {
|
||||
chunk := &sources.Chunk{
|
||||
SourceName: s.name,
|
||||
SourceID: s.SourceID(),
|
||||
JobID: s.JobID(),
|
||||
SourceType: s.Type(),
|
||||
SourceMetadata: &source_metadatapb.MetaData{
|
||||
Data: &source_metadatapb.MetaData_Huggingface{
|
||||
Huggingface: &source_metadatapb.Huggingface{
|
||||
Link: sanitizer.UTF8(fmt.Sprintf("%s/%s#%s", s.conn.Endpoint, discussion.GetDiscussionPath(), comment.GetID())),
|
||||
Username: sanitizer.UTF8(comment.GetAuthor()),
|
||||
Repository: sanitizer.UTF8(fmt.Sprintf("%s/%s", s.conn.Endpoint, discussion.GetGitPath())),
|
||||
Timestamp: sanitizer.UTF8(comment.GetCreatedAt()),
|
||||
Visibility: repoInfo.visibility,
|
||||
},
|
||||
},
|
||||
},
|
||||
Data: []byte(comment.Data.Latest.Raw),
|
||||
Verify: s.verify,
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case chunksChan <- chunk:
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
651
pkg/sources/huggingface/huggingface_client_test.go
Normal file
651
pkg/sources/huggingface/huggingface_client_test.go
Normal file
|
@ -0,0 +1,651 @@
|
|||
package huggingface
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/context"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"gopkg.in/h2non/gock.v1"
|
||||
)
|
||||
|
||||
const (
|
||||
TEST_TOKEN = "test token"
|
||||
)
|
||||
|
||||
func initTestClient() *HFClient {
|
||||
return NewHFClient("https://huggingface.co", TEST_TOKEN, 10*time.Second)
|
||||
}
|
||||
|
||||
func TestGetRepo(t *testing.T) {
|
||||
resourceType := MODEL
|
||||
repoName := "test-model"
|
||||
repoOwner := "test-author"
|
||||
defer gock.Off()
|
||||
|
||||
gock.New("https://huggingface.co").
|
||||
Get("/"+APIRoute+"/"+getResourceAPIPath(resourceType)+"/"+repoName).
|
||||
MatchHeader("Authorization", "Bearer "+TEST_TOKEN).
|
||||
Reply(200).
|
||||
JSON(map[string]interface{}{
|
||||
"id": repoOwner + "/" + repoName,
|
||||
"author": repoOwner,
|
||||
"private": true,
|
||||
})
|
||||
|
||||
client := initTestClient()
|
||||
model, err := client.GetRepo(context.Background(), repoName, resourceType)
|
||||
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, model)
|
||||
assert.Equal(t, repoOwner+"/"+repoName, model.RepoID)
|
||||
assert.Equal(t, repoOwner, model.Owner)
|
||||
assert.Equal(t, true, model.IsPrivate)
|
||||
assert.False(t, gock.HasUnmatchedRequest())
|
||||
assert.True(t, gock.IsDone())
|
||||
}
|
||||
|
||||
func TestGetRepo_NotFound(t *testing.T) {
|
||||
resourceType := MODEL
|
||||
repoName := "doesnotexist"
|
||||
defer gock.Off()
|
||||
|
||||
gock.New("https://huggingface.co").
|
||||
Get("/"+APIRoute+"/"+getResourceAPIPath(resourceType)+"/"+repoName).
|
||||
MatchHeader("Authorization", "Bearer "+TEST_TOKEN).
|
||||
Reply(404).
|
||||
JSON(map[string]interface{}{
|
||||
"id": "",
|
||||
"author": "",
|
||||
"private": false,
|
||||
})
|
||||
|
||||
client := initTestClient()
|
||||
model, err := client.GetRepo(context.Background(), repoName, resourceType)
|
||||
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, model)
|
||||
assert.Equal(t, "", model.RepoID)
|
||||
assert.Equal(t, "", model.Owner)
|
||||
assert.Equal(t, false, model.IsPrivate)
|
||||
assert.False(t, gock.HasUnmatchedRequest())
|
||||
assert.True(t, gock.IsDone())
|
||||
}
|
||||
|
||||
func TestGetModel_Error(t *testing.T) {
|
||||
resourceType := MODEL
|
||||
repoName := "doesnotexist"
|
||||
defer gock.Off()
|
||||
|
||||
gock.New("https://huggingface.co").
|
||||
Get("/"+APIRoute+"/"+getResourceAPIPath(resourceType)+"/"+repoName).
|
||||
MatchHeader("Authorization", "Bearer "+TEST_TOKEN).
|
||||
Reply(500)
|
||||
|
||||
client := initTestClient()
|
||||
model, err := client.GetRepo(context.Background(), repoName, resourceType)
|
||||
|
||||
assert.NotNil(t, err)
|
||||
assert.NotNil(t, model)
|
||||
assert.False(t, gock.HasUnmatchedRequest())
|
||||
assert.True(t, gock.IsDone())
|
||||
}
|
||||
|
||||
func TestListDiscussions(t *testing.T) {
|
||||
repoInfo := repoInfo{
|
||||
fullName: "test-author/test-model",
|
||||
resourceType: MODEL,
|
||||
}
|
||||
|
||||
jsonBlob := `{
|
||||
"discussions": [
|
||||
{
|
||||
"num": 2,
|
||||
"author": {
|
||||
"avatarUrl": "/avatars/test.svg",
|
||||
"fullname": "TEST",
|
||||
"name": "test-author",
|
||||
"type": "user",
|
||||
"isPro": false,
|
||||
"isHf": false,
|
||||
"isMod": false
|
||||
},
|
||||
"repo": {
|
||||
"name": "test-author/test-model",
|
||||
"type": "model"
|
||||
},
|
||||
"title": "new PR",
|
||||
"status": "open",
|
||||
"createdAt": "2024-06-18T14:34:21.000Z",
|
||||
"isPullRequest": true,
|
||||
"numComments": 2,
|
||||
"pinned": false
|
||||
},
|
||||
{
|
||||
"num": 1,
|
||||
"author": {
|
||||
"avatarUrl": "/avatars/test.svg",
|
||||
"fullname": "TEST",
|
||||
"name": "test-author",
|
||||
"type": "user",
|
||||
"isPro": false,
|
||||
"isHf": false,
|
||||
"isMod": false
|
||||
},
|
||||
"repo": {
|
||||
"name": "test-author/test-model",
|
||||
"type": "model"
|
||||
},
|
||||
"title": "secret in comment",
|
||||
"status": "closed",
|
||||
"createdAt": "2024-06-18T14:31:57.000Z",
|
||||
"isPullRequest": false,
|
||||
"numComments": 2,
|
||||
"pinned": false
|
||||
}
|
||||
],
|
||||
"count": 2,
|
||||
"start": 0,
|
||||
"numClosedDiscussions": 1
|
||||
}`
|
||||
|
||||
defer gock.Off()
|
||||
|
||||
gock.New("https://huggingface.co").
|
||||
Get("/"+APIRoute+"/"+getResourceAPIPath(string(repoInfo.resourceType))+"/"+repoInfo.fullName+"/"+DiscussionsRoute).
|
||||
MatchHeader("Authorization", "Bearer "+TEST_TOKEN).
|
||||
Reply(200).
|
||||
JSON(jsonBlob)
|
||||
|
||||
client := initTestClient()
|
||||
discussions, err := client.ListDiscussions(context.Background(), repoInfo)
|
||||
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, discussions)
|
||||
assert.Equal(t, 2, len(discussions.Discussions))
|
||||
assert.False(t, gock.HasUnmatchedRequest())
|
||||
assert.True(t, gock.IsDone())
|
||||
}
|
||||
|
||||
func TestListDiscussions_NotFound(t *testing.T) {
|
||||
repoInfo := repoInfo{
|
||||
fullName: "test-author/doesnotexist",
|
||||
resourceType: MODEL,
|
||||
}
|
||||
defer gock.Off()
|
||||
|
||||
gock.New("https://huggingface.co").
|
||||
Get("/"+APIRoute+"/"+getResourceAPIPath(string(repoInfo.resourceType))+"/"+repoInfo.fullName+"/"+DiscussionsRoute).
|
||||
MatchHeader("Authorization", "Bearer "+TEST_TOKEN).
|
||||
Reply(404).
|
||||
JSON(map[string]interface{}{
|
||||
"discussions": []map[string]interface{}{},
|
||||
})
|
||||
|
||||
client := initTestClient()
|
||||
discussions, err := client.ListDiscussions(context.Background(), repoInfo)
|
||||
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, discussions)
|
||||
assert.Equal(t, 0, len(discussions.Discussions))
|
||||
assert.False(t, gock.HasUnmatchedRequest())
|
||||
assert.True(t, gock.IsDone())
|
||||
}
|
||||
|
||||
func TestListDiscussions_Error(t *testing.T) {
|
||||
repoInfo := repoInfo{
|
||||
fullName: "test-author/doesnotexist",
|
||||
resourceType: MODEL,
|
||||
}
|
||||
defer gock.Off()
|
||||
|
||||
gock.New("https://huggingface.co").
|
||||
Get("/"+APIRoute+"/"+getResourceAPIPath(string(repoInfo.resourceType))+"/"+repoInfo.fullName+"/"+DiscussionsRoute).
|
||||
MatchHeader("Authorization", "Bearer "+TEST_TOKEN).
|
||||
Reply(500)
|
||||
|
||||
client := initTestClient()
|
||||
discussions, err := client.ListDiscussions(context.Background(), repoInfo)
|
||||
|
||||
assert.NotNil(t, err)
|
||||
assert.NotNil(t, discussions)
|
||||
assert.False(t, gock.HasUnmatchedRequest())
|
||||
assert.True(t, gock.IsDone())
|
||||
}
|
||||
|
||||
func TestGetDiscussionByID(t *testing.T) {
|
||||
repoInfo := repoInfo{
|
||||
fullName: "test-author/test-model",
|
||||
resourceType: MODEL,
|
||||
}
|
||||
discussionID := "1"
|
||||
|
||||
jsonBlob := `{
|
||||
"author": {
|
||||
"avatarUrl": "/avatars/test.svg",
|
||||
"fullname": "TEST",
|
||||
"name": "test-author",
|
||||
"type": "user",
|
||||
"isPro": false,
|
||||
"isHf": false,
|
||||
"isMod": false
|
||||
},
|
||||
"num": 1,
|
||||
"repo": {
|
||||
"name": "test-author/test-model",
|
||||
"type": "model"
|
||||
},
|
||||
"title": "secret in initial",
|
||||
"status": "open",
|
||||
"createdAt": "2024-06-18T14:31:46.000Z",
|
||||
"events": [
|
||||
{
|
||||
"id": "525",
|
||||
"author": {
|
||||
"avatarUrl": "/avatars/test.svg",
|
||||
"fullname": "TEST",
|
||||
"name": "test-author",
|
||||
"type": "user",
|
||||
"isPro": false,
|
||||
"isHf": false,
|
||||
"isMod": false,
|
||||
"isOwner": true,
|
||||
"isOrgMember": false
|
||||
},
|
||||
"createdAt": "2024-06-18T14:31:46.000Z",
|
||||
"type": "comment",
|
||||
"data": {
|
||||
"edited": true,
|
||||
"hidden": false,
|
||||
"latest": {
|
||||
"raw": "dd",
|
||||
"html": "<p>dd</p>\n",
|
||||
"updatedAt": "2024-06-18T14:33:32.066Z",
|
||||
"author": {
|
||||
"avatarUrl": "/avatars/test.svg",
|
||||
"fullname": "TEST",
|
||||
"name": "test-author",
|
||||
"type": "user",
|
||||
"isPro": false,
|
||||
"isHf": false,
|
||||
"isMod": false
|
||||
}
|
||||
},
|
||||
"numEdits": 1,
|
||||
"editors": ["trufflej"],
|
||||
"reactions": [],
|
||||
"identifiedLanguage": {
|
||||
"language": "en",
|
||||
"probability": 0.40104949474334717
|
||||
},
|
||||
"isReport": false
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "526",
|
||||
"author": {
|
||||
"avatarUrl": "/avatars/test.svg",
|
||||
"fullname": "TEST",
|
||||
"name": "test-author",
|
||||
"type": "user",
|
||||
"isPro": false,
|
||||
"isHf": false,
|
||||
"isMod": false,
|
||||
"isOwner": true,
|
||||
"isOrgMember": false
|
||||
},
|
||||
"createdAt": "2024-06-18T14:32:40.000Z",
|
||||
"type": "status-change",
|
||||
"data": {
|
||||
"status": "closed"
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "527",
|
||||
"author": {
|
||||
"avatarUrl": "/avatars/test.svg",
|
||||
"fullname": "TEST",
|
||||
"name": "test-author",
|
||||
"type": "user",
|
||||
"isPro": false,
|
||||
"isHf": false,
|
||||
"isMod": false,
|
||||
"isOwner": true,
|
||||
"isOrgMember": false
|
||||
},
|
||||
"createdAt": "2024-06-18T14:33:27.000Z",
|
||||
"type": "status-change",
|
||||
"data": {
|
||||
"status": "open"
|
||||
}
|
||||
}
|
||||
],
|
||||
"pinned": false,
|
||||
"locked": false,
|
||||
"isPullRequest": false,
|
||||
"isReport": false
|
||||
}`
|
||||
|
||||
defer gock.Off()
|
||||
|
||||
gock.New("https://huggingface.co").
|
||||
Get("/"+APIRoute+"/"+getResourceAPIPath(string(repoInfo.resourceType))+"/"+repoInfo.fullName+"/"+DiscussionsRoute+"/"+discussionID).
|
||||
MatchHeader("Authorization", "Bearer "+TEST_TOKEN).
|
||||
Reply(200).
|
||||
JSON(jsonBlob)
|
||||
|
||||
client := initTestClient()
|
||||
discussion, err := client.GetDiscussionByID(context.Background(), repoInfo, discussionID)
|
||||
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, discussion)
|
||||
assert.Equal(t, discussionID, strconv.Itoa(discussion.ID))
|
||||
assert.Equal(t, 3, len(discussion.Events))
|
||||
assert.Equal(t, false, discussion.IsPR)
|
||||
assert.Equal(t, "secret in initial", discussion.Title)
|
||||
assert.Equal(t, repoInfo.fullName, discussion.Repo.FullName)
|
||||
assert.Equal(t, string(repoInfo.resourceType), discussion.Repo.ResourceType)
|
||||
assert.False(t, gock.HasUnmatchedRequest())
|
||||
assert.True(t, gock.IsDone())
|
||||
}
|
||||
|
||||
func TestGetDiscussionByID_NotFound(t *testing.T) {
|
||||
repoInfo := repoInfo{
|
||||
fullName: "test-author/test-model",
|
||||
resourceType: MODEL,
|
||||
}
|
||||
discussionID := "doesnotexist"
|
||||
defer gock.Off()
|
||||
|
||||
gock.New("https://huggingface.co").
|
||||
Get("/"+APIRoute+"/"+getResourceAPIPath(string(repoInfo.resourceType))+"/"+repoInfo.fullName+"/"+DiscussionsRoute+"/"+discussionID).
|
||||
MatchHeader("Authorization", "Bearer "+TEST_TOKEN).
|
||||
Reply(404).
|
||||
JSON(map[string]interface{}{})
|
||||
|
||||
client := initTestClient()
|
||||
discussion, err := client.GetDiscussionByID(context.Background(), repoInfo, discussionID)
|
||||
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, discussion)
|
||||
assert.Equal(t, 0, len(discussion.Events))
|
||||
assert.False(t, gock.HasUnmatchedRequest())
|
||||
assert.True(t, gock.IsDone())
|
||||
}
|
||||
|
||||
func TestGetDiscussionByID_Error(t *testing.T) {
|
||||
repoInfo := repoInfo{
|
||||
fullName: "test-author/test-model",
|
||||
resourceType: MODEL,
|
||||
}
|
||||
discussionID := "doesnotexist"
|
||||
defer gock.Off()
|
||||
|
||||
gock.New("https://huggingface.co").
|
||||
Get("/"+APIRoute+"/"+getResourceAPIPath(string(repoInfo.resourceType))+"/"+repoInfo.fullName+"/"+DiscussionsRoute+"/"+discussionID).
|
||||
MatchHeader("Authorization", "Bearer "+TEST_TOKEN).
|
||||
Reply(500)
|
||||
|
||||
client := initTestClient()
|
||||
discussion, err := client.GetDiscussionByID(context.Background(), repoInfo, discussionID)
|
||||
|
||||
assert.NotNil(t, err)
|
||||
assert.NotNil(t, discussion)
|
||||
assert.Equal(t, "", discussion.Title)
|
||||
assert.Equal(t, 0, len(discussion.Events))
|
||||
assert.False(t, gock.HasUnmatchedRequest())
|
||||
assert.True(t, gock.IsDone())
|
||||
}
|
||||
|
||||
func TestListReposByAuthor(t *testing.T) {
|
||||
resourceType := MODEL
|
||||
author := "test-author"
|
||||
repo := "test-model"
|
||||
repo2 := "test-model2"
|
||||
|
||||
defer gock.Off()
|
||||
|
||||
gock.New("https://huggingface.co").
|
||||
Get(fmt.Sprintf("/%s/%s", APIRoute, getResourceAPIPath(resourceType))).
|
||||
MatchParam("author", author).
|
||||
MatchParam("limit", "1000").
|
||||
MatchHeader("Authorization", "Bearer "+TEST_TOKEN).
|
||||
Reply(200).
|
||||
JSON([]map[string]interface{}{
|
||||
{
|
||||
"_id": "1",
|
||||
"id": author + "/" + repo,
|
||||
"modelId": author + "/" + repo,
|
||||
"private": true,
|
||||
},
|
||||
{
|
||||
"_id": "2",
|
||||
"id": author + "/" + repo2,
|
||||
"modelId": author + "/" + repo2,
|
||||
"private": false,
|
||||
},
|
||||
})
|
||||
|
||||
for _, mock := range gock.Pending() {
|
||||
fmt.Println(mock.Request().URLStruct.String())
|
||||
}
|
||||
|
||||
client := initTestClient()
|
||||
repos, err := client.ListReposByAuthor(context.Background(), resourceType, author)
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, repos)
|
||||
assert.Equal(t, 2, len(repos))
|
||||
// count of repos with private flag
|
||||
countOfPrivateRepos := 0
|
||||
for _, repo := range repos {
|
||||
if repo.IsPrivate {
|
||||
countOfPrivateRepos++
|
||||
}
|
||||
}
|
||||
assert.Equal(t, 1, countOfPrivateRepos)
|
||||
// there is no author field in JSON, so assert repo.Owner is empty
|
||||
for _, repo := range repos {
|
||||
assert.Equal(t, "", repo.Owner)
|
||||
}
|
||||
assert.False(t, gock.HasUnmatchedRequest())
|
||||
assert.True(t, gock.IsDone())
|
||||
}
|
||||
|
||||
func TestListReposByAuthor_NotFound(t *testing.T) {
|
||||
resourceType := MODEL
|
||||
author := "authordoesntexist"
|
||||
|
||||
defer gock.Off()
|
||||
|
||||
gock.New("https://huggingface.co").
|
||||
Get(fmt.Sprintf("/%s/%s", APIRoute, getResourceAPIPath(resourceType))).
|
||||
MatchParam("author", author).
|
||||
MatchParam("limit", "1000").
|
||||
MatchHeader("Authorization", "Bearer "+TEST_TOKEN).
|
||||
Reply(404).
|
||||
JSON([]map[string]interface{}{})
|
||||
|
||||
client := initTestClient()
|
||||
repos, err := client.ListReposByAuthor(context.Background(), resourceType, author)
|
||||
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, repos)
|
||||
assert.Equal(t, 0, len(repos))
|
||||
assert.False(t, gock.HasUnmatchedRequest())
|
||||
assert.True(t, gock.IsDone())
|
||||
}
|
||||
|
||||
func TestListReposByAuthor_Error(t *testing.T) {
|
||||
resourceType := MODEL
|
||||
author := "doesnotexist"
|
||||
|
||||
defer gock.Off()
|
||||
|
||||
gock.New("https://huggingface.co").
|
||||
Get(fmt.Sprintf("/%s/%s", APIRoute, getResourceAPIPath(resourceType))).
|
||||
MatchParam("author", author).
|
||||
MatchParam("limit", "1000").
|
||||
MatchHeader("Authorization", "Bearer "+TEST_TOKEN).
|
||||
Reply(500)
|
||||
|
||||
client := initTestClient()
|
||||
repos, err := client.ListReposByAuthor(context.Background(), resourceType, author)
|
||||
assert.NotNil(t, err)
|
||||
assert.Nil(t, repos)
|
||||
assert.False(t, gock.HasUnmatchedRequest())
|
||||
assert.True(t, gock.IsDone())
|
||||
}
|
||||
|
||||
func TestGetResourceAPIPath(t *testing.T) {
|
||||
assert.Equal(t, "models", getResourceAPIPath(MODEL))
|
||||
assert.Equal(t, "datasets", getResourceAPIPath(DATASET))
|
||||
assert.Equal(t, "spaces", getResourceAPIPath(SPACE))
|
||||
}
|
||||
|
||||
func TestGetResourceHTMLPath(t *testing.T) {
|
||||
assert.Equal(t, "", getResourceHTMLPath(MODEL))
|
||||
assert.Equal(t, "datasets", getResourceHTMLPath(DATASET))
|
||||
assert.Equal(t, "spaces", getResourceHTMLPath(SPACE))
|
||||
}
|
||||
|
||||
func TestBuildAPIURL_ValidInputs(t *testing.T) {
|
||||
endpoint := "https://huggingface.co"
|
||||
resourceType := MODEL
|
||||
repoName := "test-repo"
|
||||
|
||||
expectedURL := "https://huggingface.co/api/models/test-repo"
|
||||
|
||||
url, err := buildAPIURL(endpoint, resourceType, repoName)
|
||||
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, expectedURL, url)
|
||||
}
|
||||
|
||||
func TestBuildAPIURL_EmptyEndpoint(t *testing.T) {
|
||||
endpoint := ""
|
||||
resourceType := MODEL
|
||||
repoName := "test-repo"
|
||||
|
||||
url, err := buildAPIURL(endpoint, resourceType, repoName)
|
||||
|
||||
assert.NotNil(t, err)
|
||||
assert.Equal(t, "", url)
|
||||
assert.Equal(t, "endpoint, resourceType, and repoName must not be empty", err.Error())
|
||||
}
|
||||
|
||||
func TestBuildAPIURL_EmptyResourceType(t *testing.T) {
|
||||
endpoint := "https://huggingface.co"
|
||||
resourceType := ""
|
||||
repoName := "test-repo"
|
||||
|
||||
url, err := buildAPIURL(endpoint, resourceType, repoName)
|
||||
|
||||
assert.NotNil(t, err)
|
||||
assert.Equal(t, "", url)
|
||||
assert.Equal(t, "endpoint, resourceType, and repoName must not be empty", err.Error())
|
||||
}
|
||||
|
||||
func TestBuildAPIURL_EmptyRepoName(t *testing.T) {
|
||||
endpoint := "https://huggingface.co"
|
||||
resourceType := "model"
|
||||
repoName := ""
|
||||
|
||||
url, err := buildAPIURL(endpoint, resourceType, repoName)
|
||||
|
||||
assert.NotNil(t, err)
|
||||
assert.Equal(t, "", url)
|
||||
assert.Equal(t, "endpoint, resourceType, and repoName must not be empty", err.Error())
|
||||
}
|
||||
|
||||
func TestGetDiscussionPath_ModelResource(t *testing.T) {
|
||||
discussion := Discussion{
|
||||
Repo: RepoData{
|
||||
FullName: "test-author/test-model",
|
||||
ResourceType: "model",
|
||||
},
|
||||
ID: 1,
|
||||
}
|
||||
|
||||
expectedPath := "test-author/test-model/discussions/1"
|
||||
|
||||
path := discussion.GetDiscussionPath()
|
||||
assert.Equal(t, expectedPath, path)
|
||||
}
|
||||
|
||||
func TestGetDiscussionPath_DatasetResource(t *testing.T) {
|
||||
discussion := Discussion{
|
||||
Repo: RepoData{
|
||||
FullName: "test-author/test-dataset",
|
||||
ResourceType: "dataset",
|
||||
},
|
||||
ID: 1,
|
||||
}
|
||||
|
||||
expectedPath := "datasets/test-author/test-dataset/discussions/1"
|
||||
|
||||
path := discussion.GetDiscussionPath()
|
||||
assert.Equal(t, expectedPath, path)
|
||||
}
|
||||
|
||||
func TestGetDiscussionPath_SpaceResource(t *testing.T) {
|
||||
discussion := Discussion{
|
||||
Repo: RepoData{
|
||||
FullName: "test-author/test-space",
|
||||
ResourceType: "space",
|
||||
},
|
||||
ID: 1,
|
||||
}
|
||||
|
||||
expectedPath := "spaces/test-author/test-space/discussions/1"
|
||||
|
||||
path := discussion.GetDiscussionPath()
|
||||
assert.Equal(t, expectedPath, path)
|
||||
}
|
||||
|
||||
func TestGetGitPath_ModelResource(t *testing.T) {
|
||||
discussion := Discussion{
|
||||
Repo: RepoData{
|
||||
FullName: "test-author/test-model",
|
||||
ResourceType: "model",
|
||||
},
|
||||
ID: 1,
|
||||
}
|
||||
|
||||
expectedPath := "test-author/test-model.git"
|
||||
|
||||
path := discussion.GetGitPath()
|
||||
assert.Equal(t, expectedPath, path)
|
||||
}
|
||||
|
||||
func TestGetGitPath_DatasetResource(t *testing.T) {
|
||||
discussion := Discussion{
|
||||
Repo: RepoData{
|
||||
FullName: "test-author/test-dataset",
|
||||
ResourceType: "dataset",
|
||||
},
|
||||
ID: 1,
|
||||
}
|
||||
|
||||
expectedPath := "datasets/test-author/test-dataset.git"
|
||||
|
||||
path := discussion.GetGitPath()
|
||||
assert.Equal(t, expectedPath, path)
|
||||
}
|
||||
|
||||
func TestGetGitPath_SpaceResource(t *testing.T) {
|
||||
discussion := Discussion{
|
||||
Repo: RepoData{
|
||||
FullName: "test-author/test-space",
|
||||
ResourceType: "space",
|
||||
},
|
||||
ID: 1,
|
||||
}
|
||||
|
||||
expectedPath := "spaces/test-author/test-space.git"
|
||||
|
||||
path := discussion.GetGitPath()
|
||||
assert.Equal(t, expectedPath, path)
|
||||
}
|
986
pkg/sources/huggingface/huggingface_test.go
Normal file
986
pkg/sources/huggingface/huggingface_test.go
Normal file
|
@ -0,0 +1,986 @@
|
|||
package huggingface
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/context"
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/pb/source_metadatapb"
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/pb/sourcespb"
|
||||
"google.golang.org/protobuf/types/known/anypb"
|
||||
"gopkg.in/h2non/gock.v1"
|
||||
)
|
||||
|
||||
func createTestSource(src *sourcespb.Huggingface) (*Source, *anypb.Any) {
|
||||
s := &Source{}
|
||||
conn, err := anypb.New(src)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return s, conn
|
||||
}
|
||||
|
||||
// test include exclude ignore/include orgs, users
|
||||
|
||||
func TestInit(t *testing.T) {
|
||||
s, conn := createTestSource(&sourcespb.Huggingface{
|
||||
Endpoint: "https://huggingface.co",
|
||||
Credential: &sourcespb.Huggingface_Token{
|
||||
Token: "super secret token",
|
||||
},
|
||||
Models: []string{"user/model1", "user/model2", "user/ignorethismodel"},
|
||||
IgnoreModels: []string{"user/ignorethismodel"},
|
||||
Spaces: []string{"user/space1", "user/space2", "user/ignorethisspace"},
|
||||
IgnoreSpaces: []string{"user/ignorethisspace"},
|
||||
Datasets: []string{"user/dataset1", "user/dataset2", "user/ignorethisdataset"},
|
||||
IgnoreDatasets: []string{"user/ignorethisdataset"},
|
||||
Organizations: []string{"org1", "org2"},
|
||||
Users: []string{"user1", "user2"},
|
||||
})
|
||||
err := s.Init(context.Background(), "test - huggingface", 0, 1337, false, conn, 1)
|
||||
assert.Nil(t, err)
|
||||
|
||||
assert.ElementsMatch(t, []string{"user/model1", "user/model2"}, s.models)
|
||||
for _, model := range s.models {
|
||||
modelURL, _ := s.filteredModelsCache.Get(model)
|
||||
assert.Equal(t, modelURL, s.conn.Endpoint+"/"+model+".git")
|
||||
}
|
||||
|
||||
assert.ElementsMatch(t, []string{"user/space1", "user/space2"}, s.spaces)
|
||||
for _, space := range s.spaces {
|
||||
spaceURL, _ := s.filteredSpacesCache.Get(space)
|
||||
assert.Equal(t, spaceURL, s.conn.Endpoint+"/"+getResourceHTMLPath(SPACE)+"/"+space+".git")
|
||||
}
|
||||
|
||||
assert.ElementsMatch(t, []string{"user/dataset1", "user/dataset2"}, s.datasets)
|
||||
for _, dataset := range s.datasets {
|
||||
datasetURL, _ := s.filteredDatasetsCache.Get(dataset)
|
||||
assert.Equal(t, datasetURL, s.conn.Endpoint+"/"+getResourceHTMLPath(DATASET)+"/"+dataset+".git")
|
||||
}
|
||||
|
||||
assert.ElementsMatch(t, s.conn.Organizations, s.orgsCache.Keys())
|
||||
assert.ElementsMatch(t, s.conn.Users, s.usersCache.Keys())
|
||||
}
|
||||
|
||||
func TestGetResourceType(t *testing.T) {
|
||||
repo := "author/model1"
|
||||
s, conn := createTestSource(&sourcespb.Huggingface{
|
||||
Endpoint: "https://huggingface.co",
|
||||
Credential: &sourcespb.Huggingface_Token{
|
||||
Token: "super secret token",
|
||||
},
|
||||
Models: []string{repo},
|
||||
})
|
||||
err := s.Init(context.Background(), "test - huggingface", 0, 1337, false, conn, 1)
|
||||
assert.Nil(t, err)
|
||||
|
||||
defer gock.Off()
|
||||
|
||||
// mock the request to the huggingface api
|
||||
gock.New("https://huggingface.co").
|
||||
Get("/api/models/author/model1").
|
||||
Reply(200).
|
||||
JSON(map[string]interface{}{
|
||||
"id": "author/model1",
|
||||
"author": "author",
|
||||
"private": true,
|
||||
})
|
||||
|
||||
err = s.enumerate(context.Background())
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, MODEL, s.getResourceType(context.Background(), (s.conn.Endpoint+"/"+repo+".git")))
|
||||
}
|
||||
|
||||
func TestVisibilityOf(t *testing.T) {
|
||||
repo := "author/model1"
|
||||
s, conn := createTestSource(&sourcespb.Huggingface{
|
||||
Endpoint: "https://huggingface.co",
|
||||
Credential: &sourcespb.Huggingface_Token{
|
||||
Token: "super secret token",
|
||||
},
|
||||
Models: []string{repo},
|
||||
})
|
||||
err := s.Init(context.Background(), "test - huggingface", 0, 1337, false, conn, 1)
|
||||
assert.Nil(t, err)
|
||||
|
||||
defer gock.Off()
|
||||
|
||||
// mock the request to the huggingface api
|
||||
gock.New("https://huggingface.co").
|
||||
Get("/api/models/author/model1").
|
||||
Reply(200).
|
||||
JSON(map[string]interface{}{
|
||||
"id": "author/model1",
|
||||
"author": "author",
|
||||
"private": true,
|
||||
})
|
||||
|
||||
err = s.enumerate(context.Background())
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, source_metadatapb.Visibility(1), s.visibilityOf(context.Background(), (s.conn.Endpoint+"/"+repo+".git")))
|
||||
}
|
||||
|
||||
func TestEnumerate(t *testing.T) {
|
||||
s, conn := createTestSource(&sourcespb.Huggingface{
|
||||
Endpoint: "https://huggingface.co",
|
||||
Credential: &sourcespb.Huggingface_Token{
|
||||
Token: "super secret token",
|
||||
},
|
||||
Models: []string{"author/model1"},
|
||||
Datasets: []string{"author/dataset1"},
|
||||
Spaces: []string{"author/space1"},
|
||||
})
|
||||
|
||||
err := s.Init(context.Background(), "test - huggingface", 0, 1337, false, conn, 1)
|
||||
assert.Nil(t, err)
|
||||
|
||||
defer gock.Off()
|
||||
|
||||
// mock the request to the huggingface api
|
||||
gock.New("https://huggingface.co").
|
||||
Get("/api/models/author/model1").
|
||||
Reply(200).
|
||||
JSON(map[string]interface{}{
|
||||
"id": "author/model1",
|
||||
"author": "author",
|
||||
"private": true,
|
||||
})
|
||||
|
||||
gock.New("https://huggingface.co").
|
||||
Get("/api/datasets/author/dataset1").
|
||||
Reply(200).
|
||||
JSON(map[string]interface{}{
|
||||
"id": "author/dataset1",
|
||||
"author": "author",
|
||||
"private": false,
|
||||
})
|
||||
|
||||
gock.New("https://huggingface.co").
|
||||
Get("/api/spaces/author/space1").
|
||||
Reply(200).
|
||||
JSON(map[string]interface{}{
|
||||
"id": "author/space1",
|
||||
"author": "author",
|
||||
"private": false,
|
||||
})
|
||||
|
||||
err = s.enumerate(context.Background())
|
||||
assert.Nil(t, err)
|
||||
|
||||
modelGitURL := "https://huggingface.co/author/model1.git"
|
||||
datasetGitURL := "https://huggingface.co/datasets/author/dataset1.git"
|
||||
spaceGitURL := "https://huggingface.co/spaces/author/space1.git"
|
||||
|
||||
assert.Equal(t, []string{modelGitURL}, s.models)
|
||||
assert.Equal(t, []string{datasetGitURL}, s.datasets)
|
||||
assert.Equal(t, []string{spaceGitURL}, s.spaces)
|
||||
|
||||
r, _ := s.repoInfoCache.get(modelGitURL)
|
||||
assert.Equal(t, r, repoInfo{
|
||||
visibility: source_metadatapb.Visibility_private,
|
||||
resourceType: MODEL,
|
||||
owner: "author",
|
||||
name: "model1",
|
||||
fullName: "author/model1",
|
||||
})
|
||||
|
||||
r, _ = s.repoInfoCache.get(datasetGitURL)
|
||||
assert.Equal(t, r, repoInfo{
|
||||
visibility: source_metadatapb.Visibility_public,
|
||||
resourceType: DATASET,
|
||||
owner: "author",
|
||||
name: "dataset1",
|
||||
fullName: "author/dataset1",
|
||||
})
|
||||
|
||||
r, _ = s.repoInfoCache.get(spaceGitURL)
|
||||
assert.Equal(t, r, repoInfo{
|
||||
visibility: source_metadatapb.Visibility_public,
|
||||
resourceType: SPACE,
|
||||
owner: "author",
|
||||
name: "space1",
|
||||
fullName: "author/space1",
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func TestUpdateRepoList(t *testing.T) {
|
||||
s, conn := createTestSource(&sourcespb.Huggingface{
|
||||
Endpoint: "https://huggingface.co",
|
||||
Credential: &sourcespb.Huggingface_Token{
|
||||
Token: "super secret token",
|
||||
},
|
||||
})
|
||||
|
||||
err := s.Init(context.Background(), "test - huggingface", 0, 1337, false, conn, 1)
|
||||
assert.Nil(t, err)
|
||||
|
||||
modelGitURL := "https://huggingface.co/author/model1.git"
|
||||
datasetGitURL := "https://huggingface.co/datasets/author/dataset1.git"
|
||||
spaceGitURL := "https://huggingface.co/spaces/author/space1.git"
|
||||
|
||||
s.updateRepoLists(modelGitURL, MODEL)
|
||||
s.updateRepoLists(datasetGitURL, DATASET)
|
||||
s.updateRepoLists(spaceGitURL, SPACE)
|
||||
|
||||
assert.Equal(t, []string{modelGitURL}, s.models)
|
||||
assert.Equal(t, []string{datasetGitURL}, s.datasets)
|
||||
assert.Equal(t, []string{spaceGitURL}, s.spaces)
|
||||
}
|
||||
|
||||
func TestGetReposListByType(t *testing.T) {
|
||||
s, conn := createTestSource(&sourcespb.Huggingface{
|
||||
Endpoint: "https://huggingface.co",
|
||||
Credential: &sourcespb.Huggingface_Token{
|
||||
Token: "super secret token",
|
||||
},
|
||||
Models: []string{"author/model1", "author/model2"},
|
||||
Datasets: []string{"author/dataset1", "author/dataset2"},
|
||||
Spaces: []string{"author/space1", "author/space2"},
|
||||
})
|
||||
|
||||
err := s.Init(context.Background(), "test - huggingface", 0, 1337, false, conn, 1)
|
||||
assert.Nil(t, err)
|
||||
|
||||
assert.Equal(t, s.getReposListByType(MODEL), s.models)
|
||||
assert.Equal(t, s.getReposListByType(DATASET), s.datasets)
|
||||
assert.Equal(t, s.getReposListByType(SPACE), s.spaces)
|
||||
}
|
||||
|
||||
func TestGetVisibility(t *testing.T) {
|
||||
assert.Equal(t, source_metadatapb.Visibility(1), getVisibility(true))
|
||||
assert.Equal(t, source_metadatapb.Visibility(0), getVisibility(false))
|
||||
}
|
||||
|
||||
func TestEnumerateAuthorsOrg(t *testing.T) {
|
||||
s, conn := createTestSource(&sourcespb.Huggingface{
|
||||
Endpoint: "https://huggingface.co",
|
||||
Credential: &sourcespb.Huggingface_Token{
|
||||
Token: "super secret token",
|
||||
},
|
||||
Organizations: []string{"org"},
|
||||
})
|
||||
|
||||
err := s.Init(context.Background(), "test - huggingface", 0, 1337, false, conn, 1)
|
||||
assert.Nil(t, err)
|
||||
|
||||
defer gock.Off()
|
||||
|
||||
// mock the request to the huggingface api
|
||||
gock.New("https://huggingface.co").
|
||||
Get(fmt.Sprintf("/%s/%s", APIRoute, getResourceAPIPath(MODEL))).
|
||||
MatchParam("author", "org").
|
||||
MatchParam("limit", "1000").
|
||||
Reply(200).
|
||||
JSON([]map[string]interface{}{
|
||||
{
|
||||
"_id": "1",
|
||||
"id": "org/model",
|
||||
"modelId": "org/model",
|
||||
"private": true,
|
||||
},
|
||||
})
|
||||
|
||||
gock.New("https://huggingface.co").
|
||||
Get("/api/models/org/model").
|
||||
Reply(200).
|
||||
JSON(map[string]interface{}{
|
||||
"id": "org/model",
|
||||
"author": "org",
|
||||
"private": true,
|
||||
})
|
||||
|
||||
gock.New("https://huggingface.co").
|
||||
Get(fmt.Sprintf("/%s/%s", APIRoute, getResourceAPIPath(DATASET))).
|
||||
MatchParam("author", "org").
|
||||
MatchParam("limit", "1000").
|
||||
Reply(200).
|
||||
JSON([]map[string]interface{}{
|
||||
{
|
||||
"_id": "3",
|
||||
"id": "org/dataset",
|
||||
"modelId": "org/dataset",
|
||||
"private": true,
|
||||
},
|
||||
})
|
||||
|
||||
gock.New("https://huggingface.co").
|
||||
Get("/api/datasets/org/dataset").
|
||||
Reply(200).
|
||||
JSON(map[string]interface{}{
|
||||
"id": "org/dataset",
|
||||
"author": "org",
|
||||
"private": true,
|
||||
})
|
||||
|
||||
gock.New("https://huggingface.co").
|
||||
Get(fmt.Sprintf("/%s/%s", APIRoute, getResourceAPIPath(SPACE))).
|
||||
MatchParam("author", "org").
|
||||
MatchParam("limit", "1000").
|
||||
Reply(200).
|
||||
JSON([]map[string]interface{}{
|
||||
{
|
||||
"_id": "5",
|
||||
"id": "org/space",
|
||||
"modelId": "org/space",
|
||||
"private": true,
|
||||
},
|
||||
})
|
||||
|
||||
gock.New("https://huggingface.co").
|
||||
Get("/api/spaces/org/space").
|
||||
Reply(200).
|
||||
JSON(map[string]interface{}{
|
||||
"id": "org/space",
|
||||
"author": "org",
|
||||
"private": true,
|
||||
})
|
||||
|
||||
err = s.enumerate(context.Background())
|
||||
assert.Nil(t, err)
|
||||
|
||||
modelGitURL := "https://huggingface.co/org/model.git"
|
||||
datasetGitURL := "https://huggingface.co/datasets/org/dataset.git"
|
||||
spaceGitURL := "https://huggingface.co/spaces/org/space.git"
|
||||
|
||||
assert.Equal(t, []string{modelGitURL}, s.models)
|
||||
assert.Equal(t, []string{datasetGitURL}, s.datasets)
|
||||
assert.Equal(t, []string{spaceGitURL}, s.spaces)
|
||||
|
||||
r, _ := s.repoInfoCache.get(modelGitURL)
|
||||
assert.Equal(t, r, repoInfo{
|
||||
visibility: source_metadatapb.Visibility_private,
|
||||
resourceType: MODEL,
|
||||
owner: "org",
|
||||
name: "model",
|
||||
fullName: "org/model",
|
||||
})
|
||||
|
||||
r, _ = s.repoInfoCache.get(datasetGitURL)
|
||||
assert.Equal(t, r, repoInfo{
|
||||
visibility: source_metadatapb.Visibility_private,
|
||||
resourceType: DATASET,
|
||||
owner: "org",
|
||||
name: "dataset",
|
||||
fullName: "org/dataset",
|
||||
})
|
||||
|
||||
r, _ = s.repoInfoCache.get(spaceGitURL)
|
||||
assert.Equal(t, r, repoInfo{
|
||||
visibility: source_metadatapb.Visibility_private,
|
||||
resourceType: SPACE,
|
||||
owner: "org",
|
||||
name: "space",
|
||||
fullName: "org/space",
|
||||
})
|
||||
}
|
||||
|
||||
func TestEnumerateAuthorsOrgSkipAll(t *testing.T) {
|
||||
s, conn := createTestSource(&sourcespb.Huggingface{
|
||||
Endpoint: "https://huggingface.co",
|
||||
Credential: &sourcespb.Huggingface_Token{
|
||||
Token: "super secret token",
|
||||
},
|
||||
Organizations: []string{"org"},
|
||||
SkipAllModels: true,
|
||||
SkipAllDatasets: true,
|
||||
SkipAllSpaces: true,
|
||||
})
|
||||
|
||||
err := s.Init(context.Background(), "test - huggingface", 0, 1337, false, conn, 1)
|
||||
assert.Nil(t, err)
|
||||
|
||||
defer gock.Off()
|
||||
|
||||
// mock the request to the huggingface api
|
||||
gock.New("https://huggingface.co").
|
||||
Get(fmt.Sprintf("/%s/%s", APIRoute, getResourceAPIPath(MODEL))).
|
||||
MatchParam("author", "org").
|
||||
MatchParam("limit", "1000").
|
||||
Reply(200).
|
||||
JSON([]map[string]interface{}{
|
||||
{
|
||||
"_id": "1",
|
||||
"id": "org/model",
|
||||
"modelId": "org/model",
|
||||
"private": true,
|
||||
},
|
||||
})
|
||||
|
||||
gock.New("https://huggingface.co").
|
||||
Get("/api/models/org/model").
|
||||
Reply(200).
|
||||
JSON(map[string]interface{}{
|
||||
"id": "org/model",
|
||||
"author": "org",
|
||||
"private": true,
|
||||
})
|
||||
|
||||
gock.New("https://huggingface.co").
|
||||
Get(fmt.Sprintf("/%s/%s", APIRoute, getResourceAPIPath(DATASET))).
|
||||
MatchParam("author", "org").
|
||||
MatchParam("limit", "1000").
|
||||
Reply(200).
|
||||
JSON([]map[string]interface{}{
|
||||
{
|
||||
"_id": "3",
|
||||
"id": "org/dataset",
|
||||
"modelId": "org/dataset",
|
||||
"private": true,
|
||||
},
|
||||
})
|
||||
|
||||
gock.New("https://huggingface.co").
|
||||
Get("/api/datasets/org/dataset").
|
||||
Reply(200).
|
||||
JSON(map[string]interface{}{
|
||||
"id": "org/dataset",
|
||||
"author": "org",
|
||||
"private": true,
|
||||
})
|
||||
|
||||
gock.New("https://huggingface.co").
|
||||
Get(fmt.Sprintf("/%s/%s", APIRoute, getResourceAPIPath(SPACE))).
|
||||
MatchParam("author", "org").
|
||||
MatchParam("limit", "1000").
|
||||
Reply(200).
|
||||
JSON([]map[string]interface{}{
|
||||
{
|
||||
"_id": "5",
|
||||
"id": "org/space",
|
||||
"modelId": "org/space",
|
||||
"private": true,
|
||||
},
|
||||
})
|
||||
|
||||
gock.New("https://huggingface.co").
|
||||
Get("/api/spaces/org/space").
|
||||
Reply(200).
|
||||
JSON(map[string]interface{}{
|
||||
"id": "org/space",
|
||||
"author": "org",
|
||||
"private": true,
|
||||
})
|
||||
|
||||
err = s.enumerate(context.Background())
|
||||
assert.Nil(t, err)
|
||||
|
||||
modelGitURL := "https://huggingface.co/org/model.git"
|
||||
datasetGitURL := "https://huggingface.co/datasets/org/dataset.git"
|
||||
spaceGitURL := "https://huggingface.co/spaces/org/space.git"
|
||||
|
||||
assert.Equal(t, []string{}, s.models)
|
||||
assert.Equal(t, []string{}, s.datasets)
|
||||
assert.Equal(t, []string{}, s.spaces)
|
||||
|
||||
r, _ := s.repoInfoCache.get(modelGitURL)
|
||||
assert.Equal(t, r, repoInfo{})
|
||||
|
||||
r, _ = s.repoInfoCache.get(datasetGitURL)
|
||||
assert.Equal(t, r, repoInfo{})
|
||||
|
||||
r, _ = s.repoInfoCache.get(spaceGitURL)
|
||||
assert.Equal(t, r, repoInfo{})
|
||||
}
|
||||
|
||||
func TestEnumerateAuthorsOrgIgnores(t *testing.T) {
|
||||
s, conn := createTestSource(&sourcespb.Huggingface{
|
||||
Endpoint: "https://huggingface.co",
|
||||
Credential: &sourcespb.Huggingface_Token{
|
||||
Token: "super secret token",
|
||||
},
|
||||
Organizations: []string{"org"},
|
||||
IgnoreModels: []string{"org/model"},
|
||||
IgnoreDatasets: []string{"org/dataset"},
|
||||
IgnoreSpaces: []string{"org/space"},
|
||||
})
|
||||
|
||||
err := s.Init(context.Background(), "test - huggingface", 0, 1337, false, conn, 1)
|
||||
assert.Nil(t, err)
|
||||
|
||||
defer gock.Off()
|
||||
|
||||
// mock the request to the huggingface api
|
||||
gock.New("https://huggingface.co").
|
||||
Get(fmt.Sprintf("/%s/%s", APIRoute, getResourceAPIPath(MODEL))).
|
||||
MatchParam("author", "org").
|
||||
MatchParam("limit", "1000").
|
||||
Reply(200).
|
||||
JSON([]map[string]interface{}{
|
||||
{
|
||||
"_id": "1",
|
||||
"id": "org/model",
|
||||
"modelId": "org/model",
|
||||
"private": true,
|
||||
},
|
||||
})
|
||||
|
||||
gock.New("https://huggingface.co").
|
||||
Get("/api/models/org/model").
|
||||
Reply(200).
|
||||
JSON(map[string]interface{}{
|
||||
"id": "org/model",
|
||||
"author": "org",
|
||||
"private": true,
|
||||
})
|
||||
|
||||
gock.New("https://huggingface.co").
|
||||
Get(fmt.Sprintf("/%s/%s", APIRoute, getResourceAPIPath(DATASET))).
|
||||
MatchParam("author", "org").
|
||||
MatchParam("limit", "1000").
|
||||
Reply(200).
|
||||
JSON([]map[string]interface{}{
|
||||
{
|
||||
"_id": "3",
|
||||
"id": "org/dataset",
|
||||
"modelId": "org/dataset",
|
||||
"private": true,
|
||||
},
|
||||
})
|
||||
|
||||
gock.New("https://huggingface.co").
|
||||
Get("/api/datasets/org/dataset").
|
||||
Reply(200).
|
||||
JSON(map[string]interface{}{
|
||||
"id": "org/dataset",
|
||||
"author": "org",
|
||||
"private": true,
|
||||
})
|
||||
|
||||
gock.New("https://huggingface.co").
|
||||
Get(fmt.Sprintf("/%s/%s", APIRoute, getResourceAPIPath(SPACE))).
|
||||
MatchParam("author", "org").
|
||||
MatchParam("limit", "1000").
|
||||
Reply(200).
|
||||
JSON([]map[string]interface{}{
|
||||
{
|
||||
"_id": "5",
|
||||
"id": "org/space",
|
||||
"modelId": "org/space",
|
||||
"private": true,
|
||||
},
|
||||
})
|
||||
|
||||
gock.New("https://huggingface.co").
|
||||
Get("/api/spaces/org/space").
|
||||
Reply(200).
|
||||
JSON(map[string]interface{}{
|
||||
"id": "org/space",
|
||||
"author": "org",
|
||||
"private": true,
|
||||
})
|
||||
|
||||
err = s.enumerate(context.Background())
|
||||
assert.Nil(t, err)
|
||||
|
||||
modelGitURL := "https://huggingface.co/org/model.git"
|
||||
datasetGitURL := "https://huggingface.co/datasets/org/dataset.git"
|
||||
spaceGitURL := "https://huggingface.co/spaces/org/space.git"
|
||||
|
||||
assert.Equal(t, []string{}, s.models)
|
||||
assert.Equal(t, []string{}, s.datasets)
|
||||
assert.Equal(t, []string{}, s.spaces)
|
||||
|
||||
r, _ := s.repoInfoCache.get(modelGitURL)
|
||||
assert.Equal(t, r, repoInfo{})
|
||||
|
||||
r, _ = s.repoInfoCache.get(datasetGitURL)
|
||||
assert.Equal(t, r, repoInfo{})
|
||||
|
||||
r, _ = s.repoInfoCache.get(spaceGitURL)
|
||||
assert.Equal(t, r, repoInfo{})
|
||||
}
|
||||
|
||||
func TestEnumerateAuthorsUser(t *testing.T) {
|
||||
s, conn := createTestSource(&sourcespb.Huggingface{
|
||||
Endpoint: "https://huggingface.co",
|
||||
Credential: &sourcespb.Huggingface_Token{
|
||||
Token: "super secret token",
|
||||
},
|
||||
Users: []string{"user"},
|
||||
})
|
||||
|
||||
err := s.Init(context.Background(), "test - huggingface", 0, 1337, false, conn, 1)
|
||||
assert.Nil(t, err)
|
||||
|
||||
defer gock.Off()
|
||||
|
||||
gock.New("https://huggingface.co").
|
||||
Get(fmt.Sprintf("/%s/%s", APIRoute, getResourceAPIPath(MODEL))).
|
||||
MatchParam("author", "user").
|
||||
MatchParam("limit", "1000").
|
||||
Reply(200).
|
||||
JSON([]map[string]interface{}{
|
||||
{
|
||||
"_id": "2",
|
||||
"id": "user/model",
|
||||
"modelId": "user/model",
|
||||
"private": false,
|
||||
},
|
||||
})
|
||||
|
||||
gock.New("https://huggingface.co").
|
||||
Get("/api/models/user/model").
|
||||
Reply(200).
|
||||
JSON(map[string]interface{}{
|
||||
"id": "user/model",
|
||||
"author": "user",
|
||||
"private": false,
|
||||
})
|
||||
|
||||
gock.New("https://huggingface.co").
|
||||
Get(fmt.Sprintf("/%s/%s", APIRoute, getResourceAPIPath(DATASET))).
|
||||
MatchParam("author", "user").
|
||||
MatchParam("limit", "1000").
|
||||
Reply(200).
|
||||
JSON([]map[string]interface{}{
|
||||
{
|
||||
"_id": "4",
|
||||
"id": "user/dataset",
|
||||
"modelId": "user/dataset",
|
||||
"private": false,
|
||||
},
|
||||
})
|
||||
|
||||
gock.New("https://huggingface.co").
|
||||
Get("/api/datasets/user/dataset").
|
||||
Reply(200).
|
||||
JSON(map[string]interface{}{
|
||||
"id": "user/dataset",
|
||||
"author": "user",
|
||||
"private": false,
|
||||
})
|
||||
|
||||
gock.New("https://huggingface.co").
|
||||
Get(fmt.Sprintf("/%s/%s", APIRoute, getResourceAPIPath(SPACE))).
|
||||
MatchParam("author", "user").
|
||||
MatchParam("limit", "1000").
|
||||
Reply(200).
|
||||
JSON([]map[string]interface{}{
|
||||
{
|
||||
"_id": "6",
|
||||
"id": "user/space",
|
||||
"modelId": "user/space",
|
||||
"private": false,
|
||||
},
|
||||
})
|
||||
|
||||
gock.New("https://huggingface.co").
|
||||
Get("/api/spaces/user/space").
|
||||
Reply(200).
|
||||
JSON(map[string]interface{}{
|
||||
"id": "user/space",
|
||||
"author": "user",
|
||||
"private": false,
|
||||
})
|
||||
|
||||
err = s.enumerate(context.Background())
|
||||
assert.Nil(t, err)
|
||||
|
||||
modelGitURL := "https://huggingface.co/user/model.git"
|
||||
datasetGitURL := "https://huggingface.co/datasets/user/dataset.git"
|
||||
spaceGitURL := "https://huggingface.co/spaces/user/space.git"
|
||||
|
||||
assert.Equal(t, []string{modelGitURL}, s.models)
|
||||
assert.Equal(t, []string{datasetGitURL}, s.datasets)
|
||||
assert.Equal(t, []string{spaceGitURL}, s.spaces)
|
||||
|
||||
r, _ := s.repoInfoCache.get(modelGitURL)
|
||||
assert.Equal(t, r, repoInfo{
|
||||
visibility: source_metadatapb.Visibility_public,
|
||||
resourceType: MODEL,
|
||||
owner: "user",
|
||||
name: "model",
|
||||
fullName: "user/model",
|
||||
})
|
||||
|
||||
r, _ = s.repoInfoCache.get(datasetGitURL)
|
||||
assert.Equal(t, r, repoInfo{
|
||||
visibility: source_metadatapb.Visibility_public,
|
||||
resourceType: DATASET,
|
||||
owner: "user",
|
||||
name: "dataset",
|
||||
fullName: "user/dataset",
|
||||
})
|
||||
|
||||
r, _ = s.repoInfoCache.get(spaceGitURL)
|
||||
assert.Equal(t, r, repoInfo{
|
||||
visibility: source_metadatapb.Visibility_public,
|
||||
resourceType: SPACE,
|
||||
owner: "user",
|
||||
name: "space",
|
||||
fullName: "user/space",
|
||||
})
|
||||
}
|
||||
|
||||
func TestEnumerateAuthorsUserSkipAll(t *testing.T) {
|
||||
s, conn := createTestSource(&sourcespb.Huggingface{
|
||||
Endpoint: "https://huggingface.co",
|
||||
Credential: &sourcespb.Huggingface_Token{
|
||||
Token: "super secret token",
|
||||
},
|
||||
Users: []string{"user"},
|
||||
SkipAllModels: true,
|
||||
SkipAllDatasets: true,
|
||||
SkipAllSpaces: true,
|
||||
})
|
||||
|
||||
err := s.Init(context.Background(), "test - huggingface", 0, 1337, false, conn, 1)
|
||||
assert.Nil(t, err)
|
||||
|
||||
defer gock.Off()
|
||||
|
||||
gock.New("https://huggingface.co").
|
||||
Get(fmt.Sprintf("/%s/%s", APIRoute, getResourceAPIPath(MODEL))).
|
||||
MatchParam("author", "user").
|
||||
MatchParam("limit", "1000").
|
||||
Reply(200).
|
||||
JSON([]map[string]interface{}{
|
||||
{
|
||||
"_id": "2",
|
||||
"id": "user/model",
|
||||
"modelId": "user/model",
|
||||
"private": false,
|
||||
},
|
||||
})
|
||||
|
||||
gock.New("https://huggingface.co").
|
||||
Get("/api/models/user/model").
|
||||
Reply(200).
|
||||
JSON(map[string]interface{}{
|
||||
"id": "user/model",
|
||||
"author": "user",
|
||||
"private": false,
|
||||
})
|
||||
|
||||
gock.New("https://huggingface.co").
|
||||
Get(fmt.Sprintf("/%s/%s", APIRoute, getResourceAPIPath(DATASET))).
|
||||
MatchParam("author", "user").
|
||||
MatchParam("limit", "1000").
|
||||
Reply(200).
|
||||
JSON([]map[string]interface{}{
|
||||
{
|
||||
"_id": "4",
|
||||
"id": "user/dataset",
|
||||
"modelId": "user/dataset",
|
||||
"private": false,
|
||||
},
|
||||
})
|
||||
|
||||
gock.New("https://huggingface.co").
|
||||
Get("/api/datasets/user/dataset").
|
||||
Reply(200).
|
||||
JSON(map[string]interface{}{
|
||||
"id": "user/dataset",
|
||||
"author": "user",
|
||||
"private": false,
|
||||
})
|
||||
|
||||
gock.New("https://huggingface.co").
|
||||
Get(fmt.Sprintf("/%s/%s", APIRoute, getResourceAPIPath(SPACE))).
|
||||
MatchParam("author", "user").
|
||||
MatchParam("limit", "1000").
|
||||
Reply(200).
|
||||
JSON([]map[string]interface{}{
|
||||
{
|
||||
"_id": "6",
|
||||
"id": "user/space",
|
||||
"modelId": "user/space",
|
||||
"private": false,
|
||||
},
|
||||
})
|
||||
|
||||
gock.New("https://huggingface.co").
|
||||
Get("/api/spaces/user/space").
|
||||
Reply(200).
|
||||
JSON(map[string]interface{}{
|
||||
"id": "user/space",
|
||||
"author": "user",
|
||||
"private": false,
|
||||
})
|
||||
|
||||
err = s.enumerate(context.Background())
|
||||
assert.Nil(t, err)
|
||||
|
||||
modelGitURL := "https://huggingface.co/user/model.git"
|
||||
datasetGitURL := "https://huggingface.co/datasets/user/dataset.git"
|
||||
spaceGitURL := "https://huggingface.co/spaces/user/space.git"
|
||||
|
||||
assert.Equal(t, []string{}, s.models)
|
||||
assert.Equal(t, []string{}, s.datasets)
|
||||
assert.Equal(t, []string{}, s.spaces)
|
||||
|
||||
r, _ := s.repoInfoCache.get(modelGitURL)
|
||||
assert.Equal(t, r, repoInfo{})
|
||||
|
||||
r, _ = s.repoInfoCache.get(datasetGitURL)
|
||||
assert.Equal(t, r, repoInfo{})
|
||||
|
||||
r, _ = s.repoInfoCache.get(spaceGitURL)
|
||||
assert.Equal(t, r, repoInfo{})
|
||||
}
|
||||
|
||||
func TestEnumerateAuthorsUserIgnores(t *testing.T) {
|
||||
s, conn := createTestSource(&sourcespb.Huggingface{
|
||||
Endpoint: "https://huggingface.co",
|
||||
Credential: &sourcespb.Huggingface_Token{
|
||||
Token: "super secret token",
|
||||
},
|
||||
Users: []string{"user"},
|
||||
IgnoreModels: []string{"user/model"},
|
||||
IgnoreDatasets: []string{"user/dataset"},
|
||||
IgnoreSpaces: []string{"user/space"},
|
||||
})
|
||||
|
||||
err := s.Init(context.Background(), "test - huggingface", 0, 1337, false, conn, 1)
|
||||
assert.Nil(t, err)
|
||||
|
||||
defer gock.Off()
|
||||
|
||||
gock.New("https://huggingface.co").
|
||||
Get(fmt.Sprintf("/%s/%s", APIRoute, getResourceAPIPath(MODEL))).
|
||||
MatchParam("author", "user").
|
||||
MatchParam("limit", "1000").
|
||||
Reply(200).
|
||||
JSON([]map[string]interface{}{
|
||||
{
|
||||
"_id": "2",
|
||||
"id": "user/model",
|
||||
"modelId": "user/model",
|
||||
"private": false,
|
||||
},
|
||||
})
|
||||
|
||||
gock.New("https://huggingface.co").
|
||||
Get("/api/models/user/model").
|
||||
Reply(200).
|
||||
JSON(map[string]interface{}{
|
||||
"id": "user/model",
|
||||
"author": "user",
|
||||
"private": false,
|
||||
})
|
||||
|
||||
gock.New("https://huggingface.co").
|
||||
Get(fmt.Sprintf("/%s/%s", APIRoute, getResourceAPIPath(DATASET))).
|
||||
MatchParam("author", "user").
|
||||
MatchParam("limit", "1000").
|
||||
Reply(200).
|
||||
JSON([]map[string]interface{}{
|
||||
{
|
||||
"_id": "4",
|
||||
"id": "user/dataset",
|
||||
"modelId": "user/dataset",
|
||||
"private": false,
|
||||
},
|
||||
})
|
||||
|
||||
gock.New("https://huggingface.co").
|
||||
Get("/api/datasets/user/dataset").
|
||||
Reply(200).
|
||||
JSON(map[string]interface{}{
|
||||
"id": "user/dataset",
|
||||
"author": "user",
|
||||
"private": false,
|
||||
})
|
||||
|
||||
gock.New("https://huggingface.co").
|
||||
Get(fmt.Sprintf("/%s/%s", APIRoute, getResourceAPIPath(SPACE))).
|
||||
MatchParam("author", "user").
|
||||
MatchParam("limit", "1000").
|
||||
Reply(200).
|
||||
JSON([]map[string]interface{}{
|
||||
{
|
||||
"_id": "6",
|
||||
"id": "user/space",
|
||||
"modelId": "user/space",
|
||||
"private": false,
|
||||
},
|
||||
})
|
||||
|
||||
gock.New("https://huggingface.co").
|
||||
Get("/api/spaces/user/space").
|
||||
Reply(200).
|
||||
JSON(map[string]interface{}{
|
||||
"id": "user/space",
|
||||
"author": "user",
|
||||
"private": false,
|
||||
})
|
||||
|
||||
err = s.enumerate(context.Background())
|
||||
assert.Nil(t, err)
|
||||
|
||||
modelGitURL := "https://huggingface.co/user/model.git"
|
||||
datasetGitURL := "https://huggingface.co/datasets/user/dataset.git"
|
||||
spaceGitURL := "https://huggingface.co/spaces/user/space.git"
|
||||
|
||||
assert.Equal(t, []string{}, s.models)
|
||||
assert.Equal(t, []string{}, s.datasets)
|
||||
assert.Equal(t, []string{}, s.spaces)
|
||||
|
||||
r, _ := s.repoInfoCache.get(modelGitURL)
|
||||
assert.Equal(t, r, repoInfo{})
|
||||
|
||||
r, _ = s.repoInfoCache.get(datasetGitURL)
|
||||
assert.Equal(t, r, repoInfo{})
|
||||
|
||||
r, _ = s.repoInfoCache.get(spaceGitURL)
|
||||
assert.Equal(t, r, repoInfo{})
|
||||
}
|
||||
|
||||
func TestVerifySlashSeparatedStrings(t *testing.T) {
|
||||
assert.Error(t, verifySlashSeparatedStrings([]string{"orgmodel"}))
|
||||
assert.NoError(t, verifySlashSeparatedStrings([]string{"org/model"}))
|
||||
assert.Error(t, verifySlashSeparatedStrings([]string{"org/model", "orgmodel2"}))
|
||||
assert.NoError(t, verifySlashSeparatedStrings([]string{"org/model", "org/model2"}))
|
||||
}
|
||||
|
||||
func TestValidateIgnoreIncludeReposRepos(t *testing.T) {
|
||||
s, conn := createTestSource(&sourcespb.Huggingface{
|
||||
Endpoint: "https://huggingface.co",
|
||||
Credential: &sourcespb.Huggingface_Token{
|
||||
Token: "super secret token",
|
||||
},
|
||||
Organizations: []string{"org"},
|
||||
IgnoreModels: []string{"orgmodel1"},
|
||||
IgnoreDatasets: []string{"org/dataset1"},
|
||||
IgnoreSpaces: []string{"org/space1"},
|
||||
})
|
||||
|
||||
err := s.Init(context.Background(), "test - huggingface", 0, 1337, false, conn, 1)
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
|
||||
func TestValidateIgnoreIncludeReposDatasets(t *testing.T) {
|
||||
s, conn := createTestSource(&sourcespb.Huggingface{
|
||||
Endpoint: "https://huggingface.co",
|
||||
Credential: &sourcespb.Huggingface_Token{
|
||||
Token: "super secret token",
|
||||
},
|
||||
Organizations: []string{"org"},
|
||||
IgnoreModels: []string{"org/model1"},
|
||||
IgnoreDatasets: []string{"orgdataset1"},
|
||||
IgnoreSpaces: []string{"org/space1"},
|
||||
})
|
||||
|
||||
err := s.Init(context.Background(), "test - huggingface", 0, 1337, false, conn, 1)
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
|
||||
func TestValidateIgnoreIncludeReposSpaces(t *testing.T) {
|
||||
s, conn := createTestSource(&sourcespb.Huggingface{
|
||||
Endpoint: "https://huggingface.co",
|
||||
Credential: &sourcespb.Huggingface_Token{
|
||||
Token: "super secret token",
|
||||
},
|
||||
Organizations: []string{"org"},
|
||||
IgnoreModels: []string{"org/model1"},
|
||||
IgnoreDatasets: []string{"org/dataset1"},
|
||||
IgnoreSpaces: []string{"orgspace1"},
|
||||
})
|
||||
|
||||
err := s.Init(context.Background(), "test - huggingface", 0, 1337, false, conn, 1)
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
|
||||
// repeat this with all skip flags, and then include/ignore flags
|
72
pkg/sources/huggingface/repo.go
Normal file
72
pkg/sources/huggingface/repo.go
Normal file
|
@ -0,0 +1,72 @@
|
|||
package huggingface
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
gogit "github.com/go-git/go-git/v5"
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/context"
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/pb/source_metadatapb"
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/pb/sourcespb"
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/sources/git"
|
||||
)
|
||||
|
||||
type repoInfoCache struct {
|
||||
mu sync.RWMutex
|
||||
cache map[string]repoInfo
|
||||
}
|
||||
|
||||
func newRepoInfoCache() repoInfoCache {
|
||||
return repoInfoCache{
|
||||
cache: make(map[string]repoInfo),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *repoInfoCache) put(repoURL string, info repoInfo) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.cache[repoURL] = info
|
||||
}
|
||||
|
||||
func (r *repoInfoCache) get(repoURL string) (repoInfo, bool) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
info, ok := r.cache[repoURL]
|
||||
return info, ok
|
||||
}
|
||||
|
||||
type repoInfo struct {
|
||||
owner string
|
||||
name string
|
||||
fullName string
|
||||
visibility source_metadatapb.Visibility
|
||||
resourceType resourceType
|
||||
}
|
||||
|
||||
func (s *Source) cloneRepo(
|
||||
ctx context.Context,
|
||||
repoURL string,
|
||||
) (string, *gogit.Repository, error) {
|
||||
var (
|
||||
path string
|
||||
repo *gogit.Repository
|
||||
err error
|
||||
)
|
||||
|
||||
switch s.conn.GetCredential().(type) {
|
||||
case *sourcespb.Huggingface_Unauthenticated:
|
||||
path, repo, err = git.CloneRepoUsingUnauthenticated(ctx, repoURL)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
case *sourcespb.Huggingface_Token:
|
||||
path, repo, err = git.CloneRepoUsingToken(ctx, s.huggingfaceToken, repoURL, "")
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
default:
|
||||
return "", nil, fmt.Errorf("unhandled credential type for repo %s: %T", repoURL, s.conn.GetCredential())
|
||||
}
|
||||
return path, repo, nil
|
||||
}
|
|
@ -136,6 +136,19 @@ message GCS {
|
|||
string content_type = 8;
|
||||
}
|
||||
|
||||
message Huggingface {
|
||||
string link = 1;
|
||||
string username = 2;
|
||||
string repository = 3;
|
||||
string commit = 4;
|
||||
string email = 5;
|
||||
string file = 6;
|
||||
string timestamp = 7;
|
||||
int64 line = 8;
|
||||
Visibility visibility = 9;
|
||||
string resource_type = 10;
|
||||
}
|
||||
|
||||
message Jira {
|
||||
string issue = 1;
|
||||
string author = 2;
|
||||
|
@ -351,5 +364,6 @@ message MetaData {
|
|||
Postman postman = 29;
|
||||
Webhook webhook = 30;
|
||||
Elasticsearch elasticsearch = 31;
|
||||
Huggingface huggingface = 32;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -48,6 +48,7 @@ enum SourceType {
|
|||
SOURCE_TYPE_POSTMAN = 33;
|
||||
SOURCE_TYPE_WEBHOOK = 34;
|
||||
SOURCE_TYPE_ELASTICSEARCH = 35;
|
||||
SOURCE_TYPE_HUGGINGFACE = 36;
|
||||
}
|
||||
|
||||
message LocalSource {
|
||||
|
@ -248,6 +249,30 @@ message GoogleDrive {
|
|||
}
|
||||
}
|
||||
|
||||
message Huggingface {
|
||||
string endpoint = 1 [(validate.rules).string.uri_ref = true];
|
||||
oneof credential {
|
||||
string token = 2;
|
||||
credentials.Unauthenticated unauthenticated = 3;
|
||||
}
|
||||
repeated string models = 4;
|
||||
repeated string spaces = 5;
|
||||
repeated string datasets = 12;
|
||||
repeated string organizations = 6;
|
||||
repeated string users = 15;
|
||||
repeated string ignore_models = 7;
|
||||
repeated string include_models = 8;
|
||||
repeated string ignore_spaces = 9;
|
||||
repeated string include_spaces = 10;
|
||||
repeated string ignore_datasets = 13;
|
||||
repeated string include_datasets = 14;
|
||||
bool skip_all_models = 16;
|
||||
bool skip_all_spaces = 17;
|
||||
bool skip_all_datasets = 18;
|
||||
bool include_discussions = 11;
|
||||
bool include_prs = 19;
|
||||
}
|
||||
|
||||
message JIRA {
|
||||
string endpoint = 1 [(validate.rules).string.uri_ref = true];
|
||||
oneof credential {
|
||||
|
|
Loading…
Reference in a new issue