Mirror of https://github.com/trufflesecurity/trufflehog.git (synced 2024-11-10 07:04:24 +00:00)
New Source: HuggingFace (#3000)
* initial spike on hf
* added in user and org enum
* adding huggingface source
* updated with lint suggestions
* updated readme
* addressing resources that require org approval to access
* removing unneeded code
* updating with new error msg for 403
* deleted unused code + added resource check in main
This commit is contained in:
parent e9206c66bb
commit 01a1499600

14 changed files with 4568 additions and 964 deletions
README.md (21 lines changed)
@@ -301,6 +301,27 @@ trufflehog elasticsearch \
  --api-key 'MlVtVjBZ...ZSYlduYnF1djh3NG5FQQ=='
```

## 15. Scan HuggingFace

### Scan a HuggingFace Model, Dataset or Space

```bash
trufflehog huggingface --model <username/modelname> --space <username/spacename> --dataset <username/datasetname>
```

### Scan all Models, Datasets and Spaces belonging to a HuggingFace Org/User

```bash
trufflehog huggingface --org <orgname> --user <username>
```

Optionally, skip scanning a type of resource with `--skip-all-models`, `--skip-all-datasets`, `--skip-all-spaces`, or skip a particular resource with `--ignore-models/datasets/spaces <resource-name>`.

### Scan Discussion and PR Comments

```bash
trufflehog huggingface --model <username/modelname> --include-discussions --include-prs
```

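Not part of the diff: an authenticated variant of the commands above, assuming the `--token` flag (and its `HUGGINGFACE_TOKEN` environment variable) added in main.go below. The org name follows the example given in the flag help; the token value is a placeholder.

```bash
export HUGGINGFACE_TOKEN='<your token>'
trufflehog huggingface --org trufflesecurity --include-discussions --include-prs
```
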
# :question: FAQ

- All I see is `🐷🔑🐷 TruffleHog. Unearth your secrets. 🐷🔑🐷` and the program exits, what gives?
main.go (54 lines changed)
@@ -198,6 +198,27 @@ var (
    jenkinsPassword              = jenkinsScan.Flag("password", "Jenkins password").Envar("JENKINS_PASSWORD").String()
    jenkinsInsecureSkipVerifyTLS = jenkinsScan.Flag("insecure-skip-verify-tls", "Skip TLS verification").Envar("JENKINS_INSECURE_SKIP_VERIFY_TLS").Bool()

    huggingfaceScan     = cli.Command("huggingface", "Find credentials in HuggingFace datasets, models and spaces.")
    huggingfaceEndpoint = huggingfaceScan.Flag("endpoint", "HuggingFace endpoint.").Default("https://huggingface.co").String()
    huggingfaceModels   = huggingfaceScan.Flag("model", "HuggingFace model to scan. You can repeat this flag. Example: 'username/model'").Strings()
    huggingfaceSpaces   = huggingfaceScan.Flag("space", "HuggingFace space to scan. You can repeat this flag. Example: 'username/space'").Strings()
    huggingfaceDatasets = huggingfaceScan.Flag("dataset", "HuggingFace dataset to scan. You can repeat this flag. Example: 'username/dataset'").Strings()
    huggingfaceOrgs     = huggingfaceScan.Flag("org", `HuggingFace organization to scan. You can repeat this flag. Example: "trufflesecurity"`).Strings()
    huggingfaceUsers    = huggingfaceScan.Flag("user", `HuggingFace user to scan. You can repeat this flag. Example: "trufflesecurity"`).Strings()
    huggingfaceToken    = huggingfaceScan.Flag("token", "HuggingFace token. Can be provided with environment variable HUGGINGFACE_TOKEN.").Envar("HUGGINGFACE_TOKEN").String()

    huggingfaceIncludeModels      = huggingfaceScan.Flag("include-models", "Models to include in scan. You can repeat this flag. Must use HuggingFace model full name. Example: 'username/model' (Only used with --user or --org)").Strings()
    huggingfaceIncludeSpaces      = huggingfaceScan.Flag("include-spaces", "Spaces to include in scan. You can repeat this flag. Must use HuggingFace space full name. Example: 'username/space' (Only used with --user or --org)").Strings()
    huggingfaceIncludeDatasets    = huggingfaceScan.Flag("include-datasets", "Datasets to include in scan. You can repeat this flag. Must use HuggingFace dataset full name. Example: 'username/dataset' (Only used with --user or --org)").Strings()
    huggingfaceIgnoreModels       = huggingfaceScan.Flag("ignore-models", "Models to ignore in scan. You can repeat this flag. Must use HuggingFace model full name. Example: 'username/model' (Only used with --user or --org)").Strings()
    huggingfaceIgnoreSpaces       = huggingfaceScan.Flag("ignore-spaces", "Spaces to ignore in scan. You can repeat this flag. Must use HuggingFace space full name. Example: 'username/space' (Only used with --user or --org)").Strings()
    huggingfaceIgnoreDatasets     = huggingfaceScan.Flag("ignore-datasets", "Datasets to ignore in scan. You can repeat this flag. Must use HuggingFace dataset full name. Example: 'username/dataset' (Only used with --user or --org)").Strings()
    huggingfaceSkipAllModels      = huggingfaceScan.Flag("skip-all-models", "Skip all model scans. (Only used with --user or --org)").Bool()
    huggingfaceSkipAllSpaces      = huggingfaceScan.Flag("skip-all-spaces", "Skip all space scans. (Only used with --user or --org)").Bool()
    huggingfaceSkipAllDatasets    = huggingfaceScan.Flag("skip-all-datasets", "Skip all dataset scans. (Only used with --user or --org)").Bool()
    huggingfaceIncludeDiscussions = huggingfaceScan.Flag("include-discussions", "Include discussions in scan.").Bool()
    huggingfaceIncludePrs         = huggingfaceScan.Flag("include-prs", "Include pull requests in scan.").Bool()

    usingTUI = false
)

@@ -738,6 +759,39 @@ func runSingleScan(ctx context.Context, cmd string, cfg engine.Config) (metrics,
        if err := eng.ScanJenkins(ctx, cfg); err != nil {
            return scanMetrics, fmt.Errorf("failed to scan Jenkins: %v", err)
        }
    case huggingfaceScan.FullCommand():
        if *huggingfaceEndpoint != "" {
            *huggingfaceEndpoint = strings.TrimRight(*huggingfaceEndpoint, "/")
        }

        if len(*huggingfaceModels) == 0 && len(*huggingfaceSpaces) == 0 && len(*huggingfaceDatasets) == 0 && len(*huggingfaceOrgs) == 0 && len(*huggingfaceUsers) == 0 {
            return scanMetrics, fmt.Errorf("invalid config: you must specify at least one organization, user, model, space or dataset")
        }

        cfg := engine.HuggingfaceConfig{
            Endpoint:           *huggingfaceEndpoint,
            Models:             *huggingfaceModels,
            Spaces:             *huggingfaceSpaces,
            Datasets:           *huggingfaceDatasets,
            Organizations:      *huggingfaceOrgs,
            Users:              *huggingfaceUsers,
            Token:              *huggingfaceToken,
            IncludeModels:      *huggingfaceIncludeModels,
            IncludeSpaces:      *huggingfaceIncludeSpaces,
            IncludeDatasets:    *huggingfaceIncludeDatasets,
            IgnoreModels:       *huggingfaceIgnoreModels,
            IgnoreSpaces:       *huggingfaceIgnoreSpaces,
            IgnoreDatasets:     *huggingfaceIgnoreDatasets,
            SkipAllModels:      *huggingfaceSkipAllModels,
            SkipAllSpaces:      *huggingfaceSkipAllSpaces,
            SkipAllDatasets:    *huggingfaceSkipAllDatasets,
            IncludeDiscussions: *huggingfaceIncludeDiscussions,
            IncludePrs:         *huggingfaceIncludePrs,
            Concurrency:        *concurrency,
        }
        if err := eng.ScanHuggingface(ctx, cfg); err != nil {
            return scanMetrics, fmt.Errorf("failed to scan HuggingFace: %v", err)
        }
    default:
        return scanMetrics, fmt.Errorf("invalid command: %s", cmd)
    }
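Not part of the commit: one way the flags above might be combined, scanning an org while skipping datasets and ignoring a single model by its full name. The org follows the flag help's example; the model name is hypothetical.

```bash
trufflehog huggingface --org trufflesecurity --skip-all-datasets --ignore-models trufflesecurity/test-model
```
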
pkg/engine/huggingface.go (new file, 80 lines)
@@ -0,0 +1,80 @@
package engine

import (
    "google.golang.org/protobuf/proto"
    "google.golang.org/protobuf/types/known/anypb"

    "github.com/trufflesecurity/trufflehog/v3/pkg/context"
    "github.com/trufflesecurity/trufflehog/v3/pkg/pb/sourcespb"
    "github.com/trufflesecurity/trufflehog/v3/pkg/sources/huggingface"
)

// HuggingfaceConfig represents the configuration for HuggingFace.
type HuggingfaceConfig struct {
    Endpoint           string
    Models             []string
    Spaces             []string
    Datasets           []string
    Organizations      []string
    Users              []string
    IncludeModels      []string
    IgnoreModels       []string
    IncludeSpaces      []string
    IgnoreSpaces       []string
    IncludeDatasets    []string
    IgnoreDatasets     []string
    SkipAllModels      bool
    SkipAllSpaces      bool
    SkipAllDatasets    bool
    IncludeDiscussions bool
    IncludePrs         bool
    Token              string
    Concurrency        int
}

// ScanHuggingface scans HuggingFace with the provided options.
func (e *Engine) ScanHuggingface(ctx context.Context, c HuggingfaceConfig) error {
    connection := sourcespb.Huggingface{
        Endpoint:           c.Endpoint,
        Models:             c.Models,
        Spaces:             c.Spaces,
        Datasets:           c.Datasets,
        Organizations:      c.Organizations,
        Users:              c.Users,
        IncludeModels:      c.IncludeModels,
        IgnoreModels:       c.IgnoreModels,
        IncludeSpaces:      c.IncludeSpaces,
        IgnoreSpaces:       c.IgnoreSpaces,
        IncludeDatasets:    c.IncludeDatasets,
        IgnoreDatasets:     c.IgnoreDatasets,
        SkipAllModels:      c.SkipAllModels,
        SkipAllSpaces:      c.SkipAllSpaces,
        SkipAllDatasets:    c.SkipAllDatasets,
        IncludeDiscussions: c.IncludeDiscussions,
        IncludePrs:         c.IncludePrs,
    }
    if len(c.Token) > 0 {
        connection.Credential = &sourcespb.Huggingface_Token{
            Token: c.Token,
        }
    } else {
        connection.Credential = &sourcespb.Huggingface_Unauthenticated{}
    }

    var conn anypb.Any
    err := anypb.MarshalFrom(&conn, &connection, proto.MarshalOptions{})
    if err != nil {
        ctx.Logger().Error(err, "failed to marshal huggingface connection")
        return err
    }

    sourceName := "trufflehog - huggingface"
    sourceID, jobID, _ := e.sourceManager.GetIDs(ctx, sourceName, sourcespb.SourceType_SOURCE_TYPE_HUGGINGFACE)

    huggingfaceSource := &huggingface.Source{}
    if err := huggingfaceSource.Init(ctx, sourceName, jobID, sourceID, true, &conn, c.Concurrency); err != nil {
        return err
    }
    _, err = e.sourceManager.Run(ctx, sourceName, huggingfaceSource)
    return err
}
File diff suppressed because it is too large.
@@ -1493,6 +1493,125 @@ var _ interface {
    ErrorName() string
} = GCSValidationError{}

// Validate checks the field values on Huggingface with the rules defined in
// the proto definition for this message. If any rules are violated, the first
// error encountered is returned, or nil if there are no violations.
func (m *Huggingface) Validate() error {
    return m.validate(false)
}

// ValidateAll checks the field values on Huggingface with the rules defined in
// the proto definition for this message. If any rules are violated, the
// result is a list of violation errors wrapped in HuggingfaceMultiError, or
// nil if none found.
func (m *Huggingface) ValidateAll() error {
    return m.validate(true)
}

func (m *Huggingface) validate(all bool) error {
    if m == nil {
        return nil
    }

    var errors []error

    // no validation rules for Link

    // no validation rules for Username

    // no validation rules for Repository

    // no validation rules for Commit

    // no validation rules for Email

    // no validation rules for File

    // no validation rules for Timestamp

    // no validation rules for Line

    // no validation rules for Visibility

    // no validation rules for ResourceType

    if len(errors) > 0 {
        return HuggingfaceMultiError(errors)
    }

    return nil
}

// HuggingfaceMultiError is an error wrapping multiple validation errors
// returned by Huggingface.ValidateAll() if the designated constraints aren't met.
type HuggingfaceMultiError []error

// Error returns a concatenation of all the error messages it wraps.
func (m HuggingfaceMultiError) Error() string {
    var msgs []string
    for _, err := range m {
        msgs = append(msgs, err.Error())
    }
    return strings.Join(msgs, "; ")
}

// AllErrors returns a list of validation violation errors.
func (m HuggingfaceMultiError) AllErrors() []error { return m }

// HuggingfaceValidationError is the validation error returned by
// Huggingface.Validate if the designated constraints aren't met.
type HuggingfaceValidationError struct {
    field  string
    reason string
    cause  error
    key    bool
}

// Field function returns field value.
func (e HuggingfaceValidationError) Field() string { return e.field }

// Reason function returns reason value.
func (e HuggingfaceValidationError) Reason() string { return e.reason }

// Cause function returns cause value.
func (e HuggingfaceValidationError) Cause() error { return e.cause }

// Key function returns key value.
func (e HuggingfaceValidationError) Key() bool { return e.key }

// ErrorName returns error name.
func (e HuggingfaceValidationError) ErrorName() string { return "HuggingfaceValidationError" }

// Error satisfies the builtin error interface
func (e HuggingfaceValidationError) Error() string {
    cause := ""
    if e.cause != nil {
        cause = fmt.Sprintf(" | caused by: %v", e.cause)
    }

    key := ""
    if e.key {
        key = "key for "
    }

    return fmt.Sprintf(
        "invalid %sHuggingface.%s: %s%s",
        key,
        e.field,
        e.reason,
        cause)
}

var _ error = HuggingfaceValidationError{}

var _ interface {
    Field() string
    Reason() string
    Key() bool
    Cause() error
    ErrorName() string
} = HuggingfaceValidationError{}

// Validate checks the field values on Jira with the rules defined in the proto
// definition for this message. If any rules are violated, the first error
// encountered is returned, or nil if there are no violations.

@@ -5074,6 +5193,47 @@ func (m *MetaData) validate(all bool) error {
            }
        }

    case *MetaData_Huggingface:
        if v == nil {
            err := MetaDataValidationError{
                field:  "Data",
                reason: "oneof value cannot be a typed-nil",
            }
            if !all {
                return err
            }
            errors = append(errors, err)
        }

        if all {
            switch v := interface{}(m.GetHuggingface()).(type) {
            case interface{ ValidateAll() error }:
                if err := v.ValidateAll(); err != nil {
                    errors = append(errors, MetaDataValidationError{
                        field:  "Huggingface",
                        reason: "embedded message failed validation",
                        cause:  err,
                    })
                }
            case interface{ Validate() error }:
                if err := v.Validate(); err != nil {
                    errors = append(errors, MetaDataValidationError{
                        field:  "Huggingface",
                        reason: "embedded message failed validation",
                        cause:  err,
                    })
                }
            }
        } else if v, ok := interface{}(m.GetHuggingface()).(interface{ Validate() error }); ok {
            if err := v.Validate(); err != nil {
                return MetaDataValidationError{
                    field:  "Huggingface",
                    reason: "embedded message failed validation",
                    cause:  err,
                }
            }
        }

    default:
        _ = v // ensures v is used
    }
File diff suppressed because it is too large.
@@ -2875,6 +2875,185 @@ var _ interface {
    ErrorName() string
} = GoogleDriveValidationError{}

// Validate checks the field values on Huggingface with the rules defined in
// the proto definition for this message. If any rules are violated, the first
// error encountered is returned, or nil if there are no violations.
func (m *Huggingface) Validate() error {
    return m.validate(false)
}

// ValidateAll checks the field values on Huggingface with the rules defined in
// the proto definition for this message. If any rules are violated, the
// result is a list of violation errors wrapped in HuggingfaceMultiError, or
// nil if none found.
func (m *Huggingface) ValidateAll() error {
    return m.validate(true)
}

func (m *Huggingface) validate(all bool) error {
    if m == nil {
        return nil
    }

    var errors []error

    if _, err := url.Parse(m.GetEndpoint()); err != nil {
        err = HuggingfaceValidationError{
            field:  "Endpoint",
            reason: "value must be a valid URI",
            cause:  err,
        }
        if !all {
            return err
        }
        errors = append(errors, err)
    }

    // no validation rules for SkipAllModels

    // no validation rules for SkipAllSpaces

    // no validation rules for SkipAllDatasets

    // no validation rules for IncludeDiscussions

    // no validation rules for IncludePrs

    switch v := m.Credential.(type) {
    case *Huggingface_Token:
        if v == nil {
            err := HuggingfaceValidationError{
                field:  "Credential",
                reason: "oneof value cannot be a typed-nil",
            }
            if !all {
                return err
            }
            errors = append(errors, err)
        }
        // no validation rules for Token
    case *Huggingface_Unauthenticated:
        if v == nil {
            err := HuggingfaceValidationError{
                field:  "Credential",
                reason: "oneof value cannot be a typed-nil",
            }
            if !all {
                return err
            }
            errors = append(errors, err)
        }

        if all {
            switch v := interface{}(m.GetUnauthenticated()).(type) {
            case interface{ ValidateAll() error }:
                if err := v.ValidateAll(); err != nil {
                    errors = append(errors, HuggingfaceValidationError{
                        field:  "Unauthenticated",
                        reason: "embedded message failed validation",
                        cause:  err,
                    })
                }
            case interface{ Validate() error }:
                if err := v.Validate(); err != nil {
                    errors = append(errors, HuggingfaceValidationError{
                        field:  "Unauthenticated",
                        reason: "embedded message failed validation",
                        cause:  err,
                    })
                }
            }
        } else if v, ok := interface{}(m.GetUnauthenticated()).(interface{ Validate() error }); ok {
            if err := v.Validate(); err != nil {
                return HuggingfaceValidationError{
                    field:  "Unauthenticated",
                    reason: "embedded message failed validation",
                    cause:  err,
                }
            }
        }

    default:
        _ = v // ensures v is used
    }

    if len(errors) > 0 {
        return HuggingfaceMultiError(errors)
    }

    return nil
}

// HuggingfaceMultiError is an error wrapping multiple validation errors
// returned by Huggingface.ValidateAll() if the designated constraints aren't met.
type HuggingfaceMultiError []error

// Error returns a concatenation of all the error messages it wraps.
func (m HuggingfaceMultiError) Error() string {
    var msgs []string
    for _, err := range m {
        msgs = append(msgs, err.Error())
    }
    return strings.Join(msgs, "; ")
}

// AllErrors returns a list of validation violation errors.
func (m HuggingfaceMultiError) AllErrors() []error { return m }

// HuggingfaceValidationError is the validation error returned by
// Huggingface.Validate if the designated constraints aren't met.
type HuggingfaceValidationError struct {
    field  string
    reason string
    cause  error
    key    bool
}

// Field function returns field value.
func (e HuggingfaceValidationError) Field() string { return e.field }

// Reason function returns reason value.
func (e HuggingfaceValidationError) Reason() string { return e.reason }

// Cause function returns cause value.
func (e HuggingfaceValidationError) Cause() error { return e.cause }

// Key function returns key value.
func (e HuggingfaceValidationError) Key() bool { return e.key }

// ErrorName returns error name.
func (e HuggingfaceValidationError) ErrorName() string { return "HuggingfaceValidationError" }

// Error satisfies the builtin error interface
func (e HuggingfaceValidationError) Error() string {
    cause := ""
    if e.cause != nil {
        cause = fmt.Sprintf(" | caused by: %v", e.cause)
    }

    key := ""
    if e.key {
        key = "key for "
    }

    return fmt.Sprintf(
        "invalid %sHuggingface.%s: %s%s",
        key,
        e.field,
        e.reason,
        cause)
}

var _ error = HuggingfaceValidationError{}

var _ interface {
    Field() string
    Reason() string
    Key() bool
    Cause() error
    ErrorName() string
} = HuggingfaceValidationError{}

// Validate checks the field values on JIRA with the rules defined in the proto
// definition for this message. If any rules are violated, the first error
// encountered is returned, or nil if there are no violations.
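Not part of the commit: a minimal sketch of how the generated validation above is invoked, building the connection message the same way pkg/engine/huggingface.go does for the unauthenticated case.

```go
package main

import (
	"fmt"

	"github.com/trufflesecurity/trufflehog/v3/pkg/pb/sourcespb"
)

func main() {
	// Same shape as the engine's unauthenticated connection.
	conn := &sourcespb.Huggingface{
		Endpoint:   "https://huggingface.co",
		Credential: &sourcespb.Huggingface_Unauthenticated{},
	}
	// Validate returns the first violation (here, Endpoint must parse as a URI).
	if err := conn.Validate(); err != nil {
		fmt.Println("invalid connection:", err)
		return
	}
	fmt.Println("connection config is valid")
}
```
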
pkg/sources/huggingface/client.go (new file, 223 lines)
@@ -0,0 +1,223 @@
package huggingface

import (
    "encoding/json"
    "errors"
    "fmt"
    "net/http"
    "time"

    "github.com/trufflesecurity/trufflehog/v3/pkg/context"
)

// Maps for API and HTML paths.
var apiPaths = map[string]string{
    DATASET: DatasetsRoute,
    MODEL:   ModelsAPIRoute,
    SPACE:   SpacesRoute,
}

var htmlPaths = map[string]string{
    DATASET: DatasetsRoute,
    MODEL:   "",
    SPACE:   SpacesRoute,
}

type Author struct {
    Username string `json:"name"`
}

type Latest struct {
    Raw string `json:"raw"`
}

type Data struct {
    Latest Latest `json:"latest"`
}

type Event struct {
    Type      string `json:"type"`
    Author    Author `json:"author"`
    CreatedAt string `json:"createdAt"`
    Data      Data   `json:"data"`
    ID        string `json:"id"`
}

func (e Event) GetAuthor() string {
    return e.Author.Username
}

func (e Event) GetCreatedAt() string {
    return e.CreatedAt
}

func (e Event) GetID() string {
    return e.ID
}

type RepoData struct {
    FullName     string `json:"name"`
    ResourceType string `json:"type"`
}

type Discussion struct {
    ID        int      `json:"num"`
    IsPR      bool     `json:"isPullRequest"`
    CreatedAt string   `json:"createdAt"`
    Title     string   `json:"title"`
    Events    []Event  `json:"events"`
    Repo      RepoData `json:"repo"`
}

func (d Discussion) GetID() string {
    return fmt.Sprint(d.ID)
}

func (d Discussion) GetTitle() string {
    return d.Title
}

func (d Discussion) GetCreatedAt() string {
    return d.CreatedAt
}

func (d Discussion) GetRepo() string {
    return d.Repo.FullName
}

// GetDiscussionPath returns the path (ex: "/models/user/repo/discussions/1") for the discussion.
func (d Discussion) GetDiscussionPath() string {
    basePath := fmt.Sprintf("%s/%s/%s", d.GetRepo(), DiscussionsRoute, d.GetID())
    if d.Repo.ResourceType == "model" {
        return basePath
    }
    return fmt.Sprintf("%s/%s", getResourceHTMLPath(d.Repo.ResourceType), basePath)
}

// GetGitPath returns the path (ex: "/models/user/repo.git") for the repo's git directory.
func (d Discussion) GetGitPath() string {
    basePath := fmt.Sprintf("%s.git", d.GetRepo())
    if d.Repo.ResourceType == "model" {
        return basePath
    }
    return fmt.Sprintf("%s/%s", getResourceHTMLPath(d.Repo.ResourceType), basePath)
}

type DiscussionList struct {
    Discussions []Discussion `json:"discussions"`
}

type Repo struct {
    IsPrivate bool   `json:"private"`
    Owner     string `json:"author"`
    RepoID    string `json:"id"`
}

type HFClient struct {
    BaseURL    string
    APIKey     string
    HTTPClient *http.Client
}

// NewHFClient creates a new API client.
func NewHFClient(baseURL, apiKey string, timeout time.Duration) *HFClient {
    return &HFClient{
        BaseURL: baseURL,
        APIKey:  apiKey,
        HTTPClient: &http.Client{
            Timeout: timeout,
        },
    }
}

// get makes a GET request to the Hugging Face API.
// Note: not addressing rate limits, since they seem very permissive (ex: "If
// your account suddenly sends 10k requests then you're likely to receive 503").
func (c *HFClient) get(ctx context.Context, url string, target interface{}) error {
    ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
    defer cancel()

    req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
    if err != nil {
        return fmt.Errorf("failed to create HuggingFace API request: %w", err)
    }

    req.Header.Set("Authorization", "Bearer "+c.APIKey)

    resp, err := c.HTTPClient.Do(req)
    if err != nil {
        return fmt.Errorf("failed to make request to HuggingFace API: %w", err)
    }
    defer resp.Body.Close()

    if resp.StatusCode == http.StatusUnauthorized {
        return errors.New("invalid API key.")
    }

    if resp.StatusCode == http.StatusForbidden {
        return errors.New("access to this repo is restricted and you are not in the authorized list. Visit the repository to ask for access.")
    }

    return json.NewDecoder(resp.Body).Decode(target)
}

// GetRepo retrieves a repo from the Hugging Face API.
func (c *HFClient) GetRepo(ctx context.Context, repoName string, resourceType string) (Repo, error) {
    var repo Repo
    url, err := buildAPIURL(c.BaseURL, resourceType, repoName)
    if err != nil {
        return repo, err
    }
    err = c.get(ctx, url, &repo)
    return repo, err
}

// ListDiscussions retrieves discussions from the Hugging Face API.
func (c *HFClient) ListDiscussions(ctx context.Context, repoInfo repoInfo) (DiscussionList, error) {
    var discussions DiscussionList
    baseURL, err := buildAPIURL(c.BaseURL, string(repoInfo.resourceType), repoInfo.fullName)
    if err != nil {
        return discussions, err
    }
    url := fmt.Sprintf("%s/%s", baseURL, DiscussionsRoute)
    err = c.get(ctx, url, &discussions)
    return discussions, err
}

func (c *HFClient) GetDiscussionByID(ctx context.Context, repoInfo repoInfo, discussionID string) (Discussion, error) {
    var discussion Discussion
    baseURL, err := buildAPIURL(c.BaseURL, string(repoInfo.resourceType), repoInfo.fullName)
    if err != nil {
        return discussion, err
    }
    url := fmt.Sprintf("%s/%s/%s", baseURL, DiscussionsRoute, discussionID)
    err = c.get(ctx, url, &discussion)
    return discussion, err
}

// ListReposByAuthor retrieves repos from the Hugging Face API by author (user or org).
// Note: not addressing pagination because the API returns up to 1000 results by default,
// which should be enough for 99.99% of cases.
func (c *HFClient) ListReposByAuthor(ctx context.Context, resourceType string, author string) ([]Repo, error) {
    var repos []Repo
    url := fmt.Sprintf("%s/%s/%s?limit=1000&author=%s", c.BaseURL, APIRoute, getResourceAPIPath(resourceType), author)
    err := c.get(ctx, url, &repos)
    return repos, err
}

// getResourceAPIPath returns the API path for the given resource type.
func getResourceAPIPath(resourceType string) string {
    return apiPaths[resourceType]
}

// getResourceHTMLPath returns the HTML path for the given resource type.
func getResourceHTMLPath(resourceType string) string {
    return htmlPaths[resourceType]
}

func buildAPIURL(endpoint string, resourceType string, repoName string) (string, error) {
    if endpoint == "" || resourceType == "" || repoName == "" {
        return "", errors.New("endpoint, resourceType, and repoName must not be empty")
    }
    return fmt.Sprintf("%s/%s/%s/%s", endpoint, APIRoute, getResourceAPIPath(resourceType), repoName), nil
}
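Not part of the commit: a small sketch of how the exported pieces of this client could be exercised directly. It sticks to the exported API (NewHFClient, GetRepo, ListReposByAuthor); the unauthenticated token is an empty string and the model/org names are hypothetical or taken from the flag help's example. It assumes pkg/context exposes Background(), as used elsewhere in the codebase.

```go
package main

import (
	"fmt"
	"time"

	"github.com/trufflesecurity/trufflehog/v3/pkg/context"
	"github.com/trufflesecurity/trufflehog/v3/pkg/sources/huggingface"
)

func main() {
	ctx := context.Background()

	// Unauthenticated client against the public endpoint (empty API key).
	client := huggingface.NewHFClient("https://huggingface.co", "", 10*time.Second)

	// Look up a single model by full name (hypothetical name).
	repo, err := client.GetRepo(ctx, "trufflesecurity/test-model", huggingface.MODEL)
	if err != nil {
		fmt.Println("repo lookup failed:", err)
	} else {
		fmt.Println("repo:", repo.RepoID, "private:", repo.IsPrivate)
	}

	// List up to 1000 models owned by an author (org or user).
	repos, err := client.ListReposByAuthor(ctx, huggingface.MODEL, "trufflesecurity")
	if err != nil {
		fmt.Println("listing failed:", err)
		return
	}
	fmt.Println("models found:", len(repos))
}
```
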
pkg/sources/huggingface/huggingface.go (new file, 680 lines)
@@ -0,0 +1,680 @@
package huggingface

import (
    "fmt"
    "os"
    "sort"
    "strings"
    "sync"
    "sync/atomic"
    "time"

    "github.com/go-logr/logr"
    "github.com/gobwas/glob"
    "golang.org/x/sync/errgroup"
    "google.golang.org/protobuf/proto"
    "google.golang.org/protobuf/types/known/anypb"

    "github.com/trufflesecurity/trufflehog/v3/pkg/cache"
    "github.com/trufflesecurity/trufflehog/v3/pkg/cache/memory"
    "github.com/trufflesecurity/trufflehog/v3/pkg/common"
    "github.com/trufflesecurity/trufflehog/v3/pkg/context"
    "github.com/trufflesecurity/trufflehog/v3/pkg/giturl"
    "github.com/trufflesecurity/trufflehog/v3/pkg/pb/source_metadatapb"
    "github.com/trufflesecurity/trufflehog/v3/pkg/pb/sourcespb"
    "github.com/trufflesecurity/trufflehog/v3/pkg/sanitizer"
    "github.com/trufflesecurity/trufflehog/v3/pkg/sources"
    "github.com/trufflesecurity/trufflehog/v3/pkg/sources/git"
)

const (
    SourceType = sourcespb.SourceType_SOURCE_TYPE_HUGGINGFACE

    DatasetsRoute     = "datasets"
    SpacesRoute       = "spaces"
    ModelsAPIRoute    = "models"
    DiscussionsRoute  = "discussions"
    APIRoute          = "api"
    DATASET           = "dataset"
    MODEL             = "model"
    SPACE             = "space"
    defaultPagination = 100
)

type resourceType string

type Source struct {
    name             string
    huggingfaceToken string

    sourceID               sources.SourceID
    jobID                  sources.JobID
    verify                 bool
    useCustomContentWriter bool
    orgsCache              cache.Cache[string]
    usersCache             cache.Cache[string]

    models   []string
    spaces   []string
    datasets []string

    filteredModelsCache   *filteredRepoCache
    filteredSpacesCache   *filteredRepoCache
    filteredDatasetsCache *filteredRepoCache

    repoInfoCache repoInfoCache

    git *git.Git

    scanOptions *git.ScanOptions

    apiClient       *HFClient
    log             logr.Logger
    conn            *sourcespb.Huggingface
    jobPool         *errgroup.Group
    resumeInfoMutex sync.Mutex
    resumeInfoSlice []string

    skipAllModels      bool
    skipAllSpaces      bool
    skipAllDatasets    bool
    includeDiscussions bool
    includePrs         bool

    sources.Progress
    sources.CommonSourceUnitUnmarshaller
}

// Ensure the Source satisfies the interfaces at compile time.
var _ sources.Source = (*Source)(nil)
var _ sources.SourceUnitUnmarshaller = (*Source)(nil)

// WithCustomContentWriter sets the useCustomContentWriter flag on the source.
func (s *Source) WithCustomContentWriter() { s.useCustomContentWriter = true }

// Type returns the type of source.
// It is used for matching source types in configuration and job input.
func (s *Source) Type() sourcespb.SourceType {
    return SourceType
}

func (s *Source) SourceID() sources.SourceID {
    return s.sourceID
}

func (s *Source) JobID() sources.JobID {
    return s.jobID
}

// filteredRepoCache is a wrapper around cache.Cache that filters out repos
// based on include and exclude globs.
type filteredRepoCache struct {
    cache.Cache[string]
    include, exclude []glob.Glob
}

func (s *Source) newFilteredRepoCache(c cache.Cache[string], include, exclude []string) *filteredRepoCache {
    includeGlobs := make([]glob.Glob, 0, len(include))
    excludeGlobs := make([]glob.Glob, 0, len(exclude))
    for _, ig := range include {
        g, err := glob.Compile(ig)
        if err != nil {
            s.log.V(1).Info("invalid include glob", "include_value", ig, "err", err)
            continue
        }
        includeGlobs = append(includeGlobs, g)
    }
    for _, eg := range exclude {
        g, err := glob.Compile(eg)
        if err != nil {
            s.log.V(1).Info("invalid exclude glob", "exclude_value", eg, "err", err)
            continue
        }
        excludeGlobs = append(excludeGlobs, g)
    }
    return &filteredRepoCache{Cache: c, include: includeGlobs, exclude: excludeGlobs}
}

// Set overrides the cache.Cache Set method to filter out repos based on
// include and exclude globs.
func (c *filteredRepoCache) Set(key, val string) {
    if c.ignoreRepo(key) {
        return
    }
    if !c.includeRepo(key) {
        return
    }
    c.Cache.Set(key, val)
}

func (c *filteredRepoCache) ignoreRepo(s string) bool {
    for _, g := range c.exclude {
        if g.Match(s) {
            return true
        }
    }
    return false
}

func (c *filteredRepoCache) includeRepo(s string) bool {
    if len(c.include) == 0 {
        return true
    }

    for _, g := range c.include {
        if g.Match(s) {
            return true
        }
    }
    return false
}

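Not part of the file: filteredRepoCache above is unexported, but its include/exclude semantics are just gobwas/glob matching — exclude globs win, and an empty include list means "include everything". A standalone sketch of that rule, with hypothetical patterns and repo names:

```go
package main

import (
	"fmt"

	"github.com/gobwas/glob"
)

func main() {
	// Hypothetical --include-models / --ignore-models values.
	include := glob.MustCompile("trufflesecurity/*")
	exclude := glob.MustCompile("trufflesecurity/private-*")

	for _, repo := range []string{
		"trufflesecurity/driftwood",
		"trufflesecurity/private-model",
		"someone-else/model",
	} {
		// Mirrors filteredRepoCache.Set: drop excluded repos first, then require an include match.
		keep := !exclude.Match(repo) && include.Match(repo)
		fmt.Printf("%-32s keep=%v\n", repo, keep)
	}
}
```
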
// Init returns an initialized HuggingFace source.
|
||||||
|
func (s *Source) Init(aCtx context.Context, name string, jobID sources.JobID, sourceID sources.SourceID, verify bool, connection *anypb.Any, concurrency int) error {
|
||||||
|
err := git.CmdCheck()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
s.log = aCtx.Logger()
|
||||||
|
|
||||||
|
s.name = name
|
||||||
|
s.sourceID = sourceID
|
||||||
|
s.jobID = jobID
|
||||||
|
s.verify = verify
|
||||||
|
s.jobPool = &errgroup.Group{}
|
||||||
|
s.jobPool.SetLimit(concurrency)
|
||||||
|
|
||||||
|
var conn sourcespb.Huggingface
|
||||||
|
err = anypb.UnmarshalTo(connection, &conn, proto.UnmarshalOptions{})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error unmarshalling connection: %w", err)
|
||||||
|
}
|
||||||
|
s.conn = &conn
|
||||||
|
|
||||||
|
s.orgsCache = memory.New[string]()
|
||||||
|
for _, org := range s.conn.Organizations {
|
||||||
|
s.orgsCache.Set(org, org)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.usersCache = memory.New[string]()
|
||||||
|
for _, user := range s.conn.Users {
|
||||||
|
s.usersCache.Set(user, user)
|
||||||
|
}
|
||||||
|
|
||||||
|
//Verify ignore and include models, spaces, and datasets are valid
|
||||||
|
// this ensures that calling --org <org> --ignore-model <org/model> contains the proper
|
||||||
|
// repo format of org/model. Otherwise, we would scan the entire org.
|
||||||
|
if err := s.validateIgnoreIncludeRepos(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
s.filteredModelsCache = s.newFilteredRepoCache(memory.New[string](),
|
||||||
|
append(s.conn.GetModels(), s.conn.GetIncludeModels()...),
|
||||||
|
s.conn.GetIgnoreModels(),
|
||||||
|
)
|
||||||
|
|
||||||
|
s.filteredSpacesCache = s.newFilteredRepoCache(memory.New[string](),
|
||||||
|
append(s.conn.GetSpaces(), s.conn.GetIncludeSpaces()...),
|
||||||
|
s.conn.GetIgnoreSpaces(),
|
||||||
|
)
|
||||||
|
|
||||||
|
s.filteredDatasetsCache = s.newFilteredRepoCache(memory.New[string](),
|
||||||
|
append(s.conn.GetDatasets(), s.conn.GetIncludeDatasets()...),
|
||||||
|
s.conn.GetIgnoreDatasets(),
|
||||||
|
)
|
||||||
|
|
||||||
|
s.models = initializeRepos(s.filteredModelsCache, s.conn.Models, fmt.Sprintf("%s/%s.git", s.conn.Endpoint, "%s"))
|
||||||
|
s.spaces = initializeRepos(s.filteredSpacesCache, s.conn.Spaces, fmt.Sprintf("%s/%s/%s.git", s.conn.Endpoint, SpacesRoute, "%s"))
|
||||||
|
s.datasets = initializeRepos(s.filteredDatasetsCache, s.conn.Datasets, fmt.Sprintf("%s/%s/%s.git", s.conn.Endpoint, DatasetsRoute, "%s"))
|
||||||
|
s.repoInfoCache = newRepoInfoCache()
|
||||||
|
|
||||||
|
s.includeDiscussions = s.conn.IncludeDiscussions
|
||||||
|
s.includePrs = s.conn.IncludePrs
|
||||||
|
|
||||||
|
cfg := &git.Config{
|
||||||
|
SourceName: s.name,
|
||||||
|
JobID: s.jobID,
|
||||||
|
SourceID: s.sourceID,
|
||||||
|
SourceType: s.Type(),
|
||||||
|
Verify: s.verify,
|
||||||
|
Concurrency: concurrency,
|
||||||
|
SourceMetadataFunc: func(file, email, commit, timestamp, repository string, line int64) *source_metadatapb.MetaData {
|
||||||
|
return &source_metadatapb.MetaData{
|
||||||
|
Data: &source_metadatapb.MetaData_Huggingface{
|
||||||
|
Huggingface: &source_metadatapb.Huggingface{
|
||||||
|
Commit: sanitizer.UTF8(commit),
|
||||||
|
File: sanitizer.UTF8(file),
|
||||||
|
Email: sanitizer.UTF8(email),
|
||||||
|
Repository: sanitizer.UTF8(repository),
|
||||||
|
Link: giturl.GenerateLink(repository, commit, file, line),
|
||||||
|
Timestamp: sanitizer.UTF8(timestamp),
|
||||||
|
Line: line,
|
||||||
|
Visibility: s.visibilityOf(aCtx, repository),
|
||||||
|
ResourceType: s.getResourceType(aCtx, repository),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
UseCustomContentWriter: s.useCustomContentWriter,
|
||||||
|
}
|
||||||
|
s.git = git.NewGit(cfg)
|
||||||
|
|
||||||
|
s.huggingfaceToken = s.conn.GetToken()
|
||||||
|
s.apiClient = NewHFClient(s.conn.Endpoint, s.huggingfaceToken, 10*time.Second)
|
||||||
|
|
||||||
|
s.skipAllModels = s.conn.SkipAllModels
|
||||||
|
s.skipAllSpaces = s.conn.SkipAllSpaces
|
||||||
|
s.skipAllDatasets = s.conn.SkipAllDatasets
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Source) validateIgnoreIncludeRepos() error {
|
||||||
|
if err := verifySlashSeparatedStrings(s.conn.IgnoreModels); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := verifySlashSeparatedStrings(s.conn.IncludeModels); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := verifySlashSeparatedStrings(s.conn.IgnoreSpaces); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := verifySlashSeparatedStrings(s.conn.IncludeSpaces); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := verifySlashSeparatedStrings(s.conn.IgnoreDatasets); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := verifySlashSeparatedStrings(s.conn.IncludeDatasets); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func verifySlashSeparatedStrings(s []string) error {
|
||||||
|
for _, str := range s {
|
||||||
|
if !strings.Contains(str, "/") {
|
||||||
|
return fmt.Errorf("invalid owner/repo: %s", str)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func initializeRepos(cache *filteredRepoCache, repos []string, urlPattern string) []string {
|
||||||
|
returnRepos := make([]string, 0)
|
||||||
|
for _, repo := range repos {
|
||||||
|
if !cache.ignoreRepo(repo) {
|
||||||
|
url := fmt.Sprintf(urlPattern, repo)
|
||||||
|
cache.Set(repo, url)
|
||||||
|
returnRepos = append(returnRepos, repo)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return returnRepos
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Source) getResourceType(ctx context.Context, repoURL string) string {
|
||||||
|
repoInfo, ok := s.repoInfoCache.get(repoURL)
|
||||||
|
if !ok {
|
||||||
|
// This should never happen.
|
||||||
|
err := fmt.Errorf("no repoInfo for URL: %s", repoURL)
|
||||||
|
ctx.Logger().Error(err, "failed to get repository resource type")
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
return string(repoInfo.resourceType)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Source) visibilityOf(ctx context.Context, repoURL string) source_metadatapb.Visibility {
|
||||||
|
repoInfo, ok := s.repoInfoCache.get(repoURL)
|
||||||
|
if !ok {
|
||||||
|
// This should never happen.
|
||||||
|
err := fmt.Errorf("no repoInfo for URL: %s", repoURL)
|
||||||
|
ctx.Logger().Error(err, "failed to get repository visibility")
|
||||||
|
return source_metadatapb.Visibility_unknown
|
||||||
|
}
|
||||||
|
|
||||||
|
return repoInfo.visibility
|
||||||
|
}
|
||||||
|
|
||||||
|
// Chunks emits chunks of bytes over a channel.
|
||||||
|
func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk, targets ...sources.ChunkingTarget) error {
|
||||||
|
err := s.enumerate(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return s.scan(ctx, chunksChan)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Source) enumerate(ctx context.Context) error {
|
||||||
|
s.enumerateAuthors(ctx)
|
||||||
|
|
||||||
|
s.models = make([]string, 0, s.filteredModelsCache.Count())
|
||||||
|
for _, repo := range s.filteredModelsCache.Keys() {
|
||||||
|
if err := s.cacheRepoInfo(ctx, repo, MODEL, s.filteredModelsCache); err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
s.spaces = make([]string, 0, s.filteredSpacesCache.Count())
|
||||||
|
for _, repo := range s.filteredSpacesCache.Keys() {
|
||||||
|
if err := s.cacheRepoInfo(ctx, repo, SPACE, s.filteredSpacesCache); err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
s.datasets = make([]string, 0, s.filteredDatasetsCache.Count())
|
||||||
|
for _, repo := range s.filteredDatasetsCache.Keys() {
|
||||||
|
if err := s.cacheRepoInfo(ctx, repo, DATASET, s.filteredDatasetsCache); err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
s.log.Info("Completed enumeration", "num_models", len(s.models), "num_spaces", len(s.spaces), "num_datasets", len(s.datasets))
|
||||||
|
|
||||||
|
// We must sort the repos so we can resume later if necessary.
|
||||||
|
sort.Strings(s.models)
|
||||||
|
sort.Strings(s.datasets)
|
||||||
|
sort.Strings(s.spaces)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Source) cacheRepoInfo(ctx context.Context, repo string, repoType string, repoCache *filteredRepoCache) error {
|
||||||
|
repoURL, _ := repoCache.Get(repo)
|
||||||
|
repoCtx := context.WithValue(ctx, repoType, repoURL)
|
||||||
|
|
||||||
|
if _, ok := s.repoInfoCache.get(repoURL); !ok {
|
||||||
|
repoCtx.Logger().V(2).Info("Caching " + repoType + " info")
|
||||||
|
repo, err := s.apiClient.GetRepo(repoCtx, repo, repoType)
|
||||||
|
if err != nil {
|
||||||
|
repoCtx.Logger().Error(err, "failed to fetch "+repoType)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// check if repo empty
|
||||||
|
if repo.RepoID == "" {
|
||||||
|
repoCtx.Logger().Error(fmt.Errorf("no repo found for repo"), repoURL)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
s.repoInfoCache.put(repoURL, repoInfo{
|
||||||
|
owner: repo.Owner,
|
||||||
|
name: strings.Split(repo.RepoID, "/")[1],
|
||||||
|
fullName: repo.RepoID,
|
||||||
|
visibility: getVisibility(repo.IsPrivate),
|
||||||
|
resourceType: resourceType(repoType),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
s.updateRepoLists(repoURL, repoType)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getVisibility(isPrivate bool) source_metadatapb.Visibility {
|
||||||
|
if isPrivate {
|
||||||
|
return source_metadatapb.Visibility_private
|
||||||
|
}
|
||||||
|
return source_metadatapb.Visibility_public
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Source) updateRepoLists(repoURL string, repoType string) {
|
||||||
|
switch repoType {
|
||||||
|
case MODEL:
|
||||||
|
s.models = append(s.models, repoURL)
|
||||||
|
case SPACE:
|
||||||
|
s.spaces = append(s.spaces, repoURL)
|
||||||
|
case DATASET:
|
||||||
|
s.datasets = append(s.datasets, repoURL)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Source) fetchAndCacheRepos(ctx context.Context, resourceType string, org string) error {
|
||||||
|
var repos []Repo
|
||||||
|
var err error
|
||||||
|
var url string
|
||||||
|
var filteredCache *filteredRepoCache
|
||||||
|
switch resourceType {
|
||||||
|
case MODEL:
|
||||||
|
filteredCache = s.filteredModelsCache
|
||||||
|
url = fmt.Sprintf("%s/%s.git", s.conn.Endpoint, "%s")
|
||||||
|
repos, err = s.apiClient.ListReposByAuthor(ctx, MODEL, org)
|
||||||
|
case SPACE:
|
||||||
|
filteredCache = s.filteredSpacesCache
|
||||||
|
url = fmt.Sprintf("%s/%s/%s.git", s.conn.Endpoint, SpacesRoute, "%s")
|
||||||
|
repos, err = s.apiClient.ListReposByAuthor(ctx, SPACE, org)
|
||||||
|
case DATASET:
|
||||||
|
filteredCache = s.filteredDatasetsCache
|
||||||
|
url = fmt.Sprintf("%s/%s/%s.git", s.conn.Endpoint, DatasetsRoute, "%s")
|
||||||
|
repos, err = s.apiClient.ListReposByAuthor(ctx, DATASET, org)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, repo := range repos {
|
||||||
|
repoURL := fmt.Sprintf(url, repo.RepoID)
|
||||||
|
filteredCache.Set(repo.RepoID, repoURL)
|
||||||
|
if err := s.cacheRepoInfo(ctx, repo.RepoID, resourceType, filteredCache); err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Source) enumerateAuthors(ctx context.Context) {
|
||||||
|
for _, org := range s.orgsCache.Keys() {
|
||||||
|
orgCtx := context.WithValue(ctx, "organization", org)
|
||||||
|
if !s.skipAllModels {
|
||||||
|
if err := s.fetchAndCacheRepos(orgCtx, MODEL, org); err != nil {
|
||||||
|
orgCtx.Logger().Error(err, "Failed to fetch models for organization")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !s.skipAllSpaces {
|
||||||
|
if err := s.fetchAndCacheRepos(orgCtx, SPACE, org); err != nil {
|
||||||
|
orgCtx.Logger().Error(err, "Failed to fetch spaces for organization")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !s.skipAllDatasets {
|
||||||
|
if err := s.fetchAndCacheRepos(orgCtx, DATASET, org); err != nil {
|
||||||
|
orgCtx.Logger().Error(err, "Failed to fetch datasets for organization")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, user := range s.usersCache.Keys() {
|
||||||
|
userCtx := context.WithValue(ctx, "user", user)
|
||||||
|
if !s.skipAllModels {
|
||||||
|
if err := s.fetchAndCacheRepos(userCtx, MODEL, user); err != nil {
|
||||||
|
userCtx.Logger().Error(err, "Failed to fetch models for user")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !s.skipAllSpaces {
|
||||||
|
if err := s.fetchAndCacheRepos(userCtx, SPACE, user); err != nil {
|
||||||
|
userCtx.Logger().Error(err, "Failed to fetch spaces for user")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !s.skipAllDatasets {
|
||||||
|
if err := s.fetchAndCacheRepos(userCtx, DATASET, user); err != nil {
|
||||||
|
userCtx.Logger().Error(err, "Failed to fetch datasets for user")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Source) scanRepos(ctx context.Context, chunksChan chan *sources.Chunk, resourceType string) error {
|
||||||
|
var scannedCount uint64 = 1
|
||||||
|
|
||||||
|
repos := s.getReposListByType(resourceType)
|
||||||
|
|
||||||
|
s.log.V(2).Info("Found "+resourceType+" to scan", "count", len(repos))
|
||||||
|
|
||||||
|
// If there is resume information available, limit this scan to only the repos that still need scanning.
|
||||||
|
reposToScan, progressIndexOffset := sources.FilterReposToResume(repos, s.GetProgress().EncodedResumeInfo)
|
||||||
|
repos = reposToScan
|
||||||
|
|
||||||
|
scanErrs := sources.NewScanErrors()
|
||||||
|
|
||||||
|
if s.scanOptions == nil {
|
||||||
|
s.scanOptions = &git.ScanOptions{}
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, repoURL := range repos {
|
||||||
|
i, repoURL := i, repoURL
|
||||||
|
s.jobPool.Go(func() error {
|
||||||
|
if common.IsDone(ctx) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: set progress complete is being called concurrently with i
|
||||||
|
s.setProgressCompleteWithRepo(i, progressIndexOffset, repoURL, resourceType, repos)
|
||||||
|
// Ensure the repo is removed from the resume info after being scanned.
|
||||||
|
defer func(s *Source, repoURL string) {
|
||||||
|
s.resumeInfoMutex.Lock()
|
||||||
|
defer s.resumeInfoMutex.Unlock()
|
||||||
|
s.resumeInfoSlice = sources.RemoveRepoFromResumeInfo(s.resumeInfoSlice, repoURL)
|
||||||
|
}(s, repoURL)
|
||||||
|
|
||||||
|
// Scan the repository
|
||||||
|
repoInfo, ok := s.repoInfoCache.get(repoURL)
|
||||||
|
if !ok {
|
||||||
|
// This should never happen.
|
||||||
|
err := fmt.Errorf("no repoInfo for URL: %s", repoURL)
|
||||||
|
s.log.Error(err, "failed to scan "+resourceType)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
repoCtx := context.WithValues(ctx, resourceType, repoURL)
|
||||||
|
duration, err := s.cloneAndScanRepo(repoCtx, repoURL, repoInfo, chunksChan)
|
||||||
|
if err != nil {
|
||||||
|
scanErrs.Add(err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Scan discussions and PRs, if enabled.
|
||||||
|
if s.includeDiscussions || s.includePrs {
|
||||||
|
if err = s.scanDiscussions(repoCtx, repoInfo, chunksChan); err != nil {
|
||||||
|
scanErrs.Add(fmt.Errorf("error scanning discussions/PRs in repo %s: %w", repoURL, err))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
repoCtx.Logger().V(2).Info(fmt.Sprintf("scanned %d/%d "+resourceType+"s", scannedCount, len(s.models)), "duration_seconds", duration)
|
||||||
|
atomic.AddUint64(&scannedCount, 1)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = s.jobPool.Wait()
|
||||||
|
if scanErrs.Count() > 0 {
|
||||||
|
s.log.V(0).Info("failed to scan some repositories", "error_count", scanErrs.Count(), "errors", scanErrs.String())
|
||||||
|
}
|
||||||
|
s.SetProgressComplete(len(repos), len(repos), "Completed HuggingFace "+resourceType+" scan", "")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Source) getReposListByType(resourceType string) []string {
|
||||||
|
switch resourceType {
|
||||||
|
case MODEL:
|
||||||
|
return s.models
|
||||||
|
case SPACE:
|
||||||
|
return s.spaces
|
||||||
|
case DATASET:
|
||||||
|
return s.datasets
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Source) scan(ctx context.Context, chunksChan chan *sources.Chunk) error {
|
||||||
|
if err := s.scanRepos(ctx, chunksChan, MODEL); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := s.scanRepos(ctx, chunksChan, SPACE); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := s.scanRepos(ctx, chunksChan, DATASET); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Source) cloneAndScanRepo(ctx context.Context, repoURL string, repoInfo repoInfo, chunksChan chan *sources.Chunk) (time.Duration, error) {
	// logr loggers take a message plus key/value pairs, not a printf format string.
	ctx.Logger().V(2).Info("attempting to clone repo", "resource_type", repoInfo.resourceType)
	path, repo, err := s.cloneRepo(ctx, repoURL)
	if err != nil {
		return 0, err
	}
	defer os.RemoveAll(path)

	// Use the context's logger rather than a zero-value logr.Logger, which would discard output.
	ctx.Logger().V(2).Info("scanning repo", "resource_type", repoInfo.resourceType)

	start := time.Now()
	if err = s.git.ScanRepo(ctx, repo, path, s.scanOptions, sources.ChanReporter{Ch: chunksChan}); err != nil {
		return 0, fmt.Errorf("error scanning repo %s: %w", repoURL, err)
	}
	return time.Since(start), nil
}

// setProgressCompleteWithRepo calls the s.SetProgressComplete after safely setting up the encoded resume info string.
func (s *Source) setProgressCompleteWithRepo(index int, offset int, repoURL string, resourceType string, repos []string) {
	s.resumeInfoMutex.Lock()
	defer s.resumeInfoMutex.Unlock()

	// Add the repoURL to the resume info slice.
	s.resumeInfoSlice = append(s.resumeInfoSlice, repoURL)
	sort.Strings(s.resumeInfoSlice)

	// Make the resume info string from the slice.
	encodedResumeInfo := sources.EncodeResumeInfo(s.resumeInfoSlice)
	s.SetProgressComplete(index+offset, len(repos)+offset, fmt.Sprintf("%ss: %s", resourceType, repoURL), encodedResumeInfo)
}

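This function leans on sources.EncodeResumeInfo and sources.RemoveRepoFromResumeInfo, which are not shown in this diff. A minimal sketch of what such bookkeeping could look like, assuming a simple comma-joined encoding; these are hypothetical stand-ins, not the real helpers:

```go
package main

import (
	"fmt"
	"sort"
	"strings"
)

// encodeResumeInfo is an illustrative stand-in for sources.EncodeResumeInfo:
// join the in-flight repo URLs into one progress string (real encoding may differ).
func encodeResumeInfo(repos []string) string {
	sort.Strings(repos)
	return strings.Join(repos, ",")
}

// removeRepoFromResumeInfo drops a finished repo so a resumed scan skips it.
func removeRepoFromResumeInfo(repos []string, done string) []string {
	out := repos[:0] // reuse the backing array
	for _, r := range repos {
		if r != done {
			out = append(out, r)
		}
	}
	return out
}

func main() {
	inFlight := []string{"https://huggingface.co/b/model2.git", "https://huggingface.co/a/model1.git"}
	fmt.Println(encodeResumeInfo(inFlight))                                              // sorted, comma-joined
	fmt.Println(removeRepoFromResumeInfo(inFlight, "https://huggingface.co/a/model1.git")) // only model2 remains
}
```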
func (s *Source) scanDiscussions(ctx context.Context, repoInfo repoInfo, chunksChan chan *sources.Chunk) error {
	discussions, err := s.apiClient.ListDiscussions(ctx, repoInfo)
	if err != nil {
		return err
	}
	for _, discussion := range discussions.Discussions {
		if (discussion.IsPR && s.includePrs) || (!discussion.IsPR && s.includeDiscussions) {
			d, err := s.apiClient.GetDiscussionByID(ctx, repoInfo, discussion.GetID())
			if err != nil {
				return err
			}
			// Note: there is no discussion "description" or similar to chunk, only comments.
			if err = s.chunkDiscussionComments(ctx, repoInfo, d, chunksChan); err != nil {
				return err
			}
		}
	}
	return nil
}

func (s *Source) chunkDiscussionComments(ctx context.Context, repoInfo repoInfo, discussion Discussion, chunksChan chan *sources.Chunk) error {
	for _, comment := range discussion.Events {
		chunk := &sources.Chunk{
			SourceName: s.name,
			SourceID:   s.SourceID(),
			JobID:      s.JobID(),
			SourceType: s.Type(),
			SourceMetadata: &source_metadatapb.MetaData{
				Data: &source_metadatapb.MetaData_Huggingface{
					Huggingface: &source_metadatapb.Huggingface{
						Link:       sanitizer.UTF8(fmt.Sprintf("%s/%s#%s", s.conn.Endpoint, discussion.GetDiscussionPath(), comment.GetID())),
						Username:   sanitizer.UTF8(comment.GetAuthor()),
						Repository: sanitizer.UTF8(fmt.Sprintf("%s/%s", s.conn.Endpoint, discussion.GetGitPath())),
						Timestamp:  sanitizer.UTF8(comment.GetCreatedAt()),
						Visibility: repoInfo.visibility,
					},
				},
			},
			Data:   []byte(comment.Data.Latest.Raw),
			Verify: s.verify,
		}
		select {
		case <-ctx.Done():
			return ctx.Err()
		case chunksChan <- chunk:
		}
	}
	return nil
}

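The select at the end of chunkDiscussionComments is the standard Go idiom for a channel send that can be abandoned on cancellation. A small generic sketch of the same pattern:

```go
package main

import (
	"context"
	"fmt"
)

// sendOrDone tries to deliver v on ch, but gives up as soon as ctx is cancelled,
// so a slow or blocked consumer cannot wedge the producer.
func sendOrDone[T any](ctx context.Context, ch chan<- T, v T) error {
	select {
	case <-ctx.Done():
		return ctx.Err()
	case ch <- v:
		return nil
	}
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	ch := make(chan string) // unbuffered: a send blocks until someone receives
	cancel()                // no consumer here; cancellation unblocks the send attempt
	fmt.Println(sendOrDone(ctx, ch, "chunk")) // prints: context canceled
}
```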
651 pkg/sources/huggingface/huggingface_client_test.go (new file)

@@ -0,0 +1,651 @@
package huggingface

import (
	"fmt"
	"strconv"
	"testing"
	"time"

	"github.com/trufflesecurity/trufflehog/v3/pkg/context"

	"github.com/stretchr/testify/assert"
	"gopkg.in/h2non/gock.v1"
)

const (
	TEST_TOKEN = "test token"
)

func initTestClient() *HFClient {
	return NewHFClient("https://huggingface.co", TEST_TOKEN, 10*time.Second)
}

func TestGetRepo(t *testing.T) {
	resourceType := MODEL
	repoName := "test-model"
	repoOwner := "test-author"
	defer gock.Off()

	gock.New("https://huggingface.co").
		Get("/"+APIRoute+"/"+getResourceAPIPath(resourceType)+"/"+repoName).
		MatchHeader("Authorization", "Bearer "+TEST_TOKEN).
		Reply(200).
		JSON(map[string]interface{}{
			"id":      repoOwner + "/" + repoName,
			"author":  repoOwner,
			"private": true,
		})

	client := initTestClient()
	model, err := client.GetRepo(context.Background(), repoName, resourceType)

	assert.Nil(t, err)
	assert.NotNil(t, model)
	assert.Equal(t, repoOwner+"/"+repoName, model.RepoID)
	assert.Equal(t, repoOwner, model.Owner)
	assert.Equal(t, true, model.IsPrivate)
	assert.False(t, gock.HasUnmatchedRequest())
	assert.True(t, gock.IsDone())
}

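These tests stub the HuggingFace API with gock rather than hitting the network. A minimal standalone example of the same interception pattern, assuming the client under test goes through http.DefaultTransport (which gock patches by default):

```go
package main

import (
	"fmt"
	"io"
	"net/http"

	"gopkg.in/h2non/gock.v1"
)

func main() {
	defer gock.Off() // restore the real transport when done

	// Register a fake response for one specific request.
	gock.New("https://huggingface.co").
		Get("/api/models/test-author/test-model").
		Reply(200).
		JSON(map[string]interface{}{"id": "test-author/test-model", "private": true})

	// Any client built on http.DefaultTransport now hits the mock instead of the network.
	resp, err := http.Get("https://huggingface.co/api/models/test-author/test-model")
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	body, _ := io.ReadAll(resp.Body)
	fmt.Println(resp.StatusCode, string(body), gock.IsDone())
}
```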
func TestGetRepo_NotFound(t *testing.T) {
	resourceType := MODEL
	repoName := "doesnotexist"
	defer gock.Off()

	gock.New("https://huggingface.co").
		Get("/"+APIRoute+"/"+getResourceAPIPath(resourceType)+"/"+repoName).
		MatchHeader("Authorization", "Bearer "+TEST_TOKEN).
		Reply(404).
		JSON(map[string]interface{}{
			"id":      "",
			"author":  "",
			"private": false,
		})

	client := initTestClient()
	model, err := client.GetRepo(context.Background(), repoName, resourceType)

	assert.Nil(t, err)
	assert.NotNil(t, model)
	assert.Equal(t, "", model.RepoID)
	assert.Equal(t, "", model.Owner)
	assert.Equal(t, false, model.IsPrivate)
	assert.False(t, gock.HasUnmatchedRequest())
	assert.True(t, gock.IsDone())
}

func TestGetModel_Error(t *testing.T) {
	resourceType := MODEL
	repoName := "doesnotexist"
	defer gock.Off()

	gock.New("https://huggingface.co").
		Get("/"+APIRoute+"/"+getResourceAPIPath(resourceType)+"/"+repoName).
		MatchHeader("Authorization", "Bearer "+TEST_TOKEN).
		Reply(500)

	client := initTestClient()
	model, err := client.GetRepo(context.Background(), repoName, resourceType)

	assert.NotNil(t, err)
	assert.NotNil(t, model)
	assert.False(t, gock.HasUnmatchedRequest())
	assert.True(t, gock.IsDone())
}

func TestListDiscussions(t *testing.T) {
|
||||||
|
repoInfo := repoInfo{
|
||||||
|
fullName: "test-author/test-model",
|
||||||
|
resourceType: MODEL,
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonBlob := `{
|
||||||
|
"discussions": [
|
||||||
|
{
|
||||||
|
"num": 2,
|
||||||
|
"author": {
|
||||||
|
"avatarUrl": "/avatars/test.svg",
|
||||||
|
"fullname": "TEST",
|
||||||
|
"name": "test-author",
|
||||||
|
"type": "user",
|
||||||
|
"isPro": false,
|
||||||
|
"isHf": false,
|
||||||
|
"isMod": false
|
||||||
|
},
|
||||||
|
"repo": {
|
||||||
|
"name": "test-author/test-model",
|
||||||
|
"type": "model"
|
||||||
|
},
|
||||||
|
"title": "new PR",
|
||||||
|
"status": "open",
|
||||||
|
"createdAt": "2024-06-18T14:34:21.000Z",
|
||||||
|
"isPullRequest": true,
|
||||||
|
"numComments": 2,
|
||||||
|
"pinned": false
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"num": 1,
|
||||||
|
"author": {
|
||||||
|
"avatarUrl": "/avatars/test.svg",
|
||||||
|
"fullname": "TEST",
|
||||||
|
"name": "test-author",
|
||||||
|
"type": "user",
|
||||||
|
"isPro": false,
|
||||||
|
"isHf": false,
|
||||||
|
"isMod": false
|
||||||
|
},
|
||||||
|
"repo": {
|
||||||
|
"name": "test-author/test-model",
|
||||||
|
"type": "model"
|
||||||
|
},
|
||||||
|
"title": "secret in comment",
|
||||||
|
"status": "closed",
|
||||||
|
"createdAt": "2024-06-18T14:31:57.000Z",
|
||||||
|
"isPullRequest": false,
|
||||||
|
"numComments": 2,
|
||||||
|
"pinned": false
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"count": 2,
|
||||||
|
"start": 0,
|
||||||
|
"numClosedDiscussions": 1
|
||||||
|
}`
|
||||||
|
|
||||||
|
defer gock.Off()
|
||||||
|
|
||||||
|
gock.New("https://huggingface.co").
|
||||||
|
Get("/"+APIRoute+"/"+getResourceAPIPath(string(repoInfo.resourceType))+"/"+repoInfo.fullName+"/"+DiscussionsRoute).
|
||||||
|
MatchHeader("Authorization", "Bearer "+TEST_TOKEN).
|
||||||
|
Reply(200).
|
||||||
|
JSON(jsonBlob)
|
||||||
|
|
||||||
|
client := initTestClient()
|
||||||
|
discussions, err := client.ListDiscussions(context.Background(), repoInfo)
|
||||||
|
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.NotNil(t, discussions)
|
||||||
|
assert.Equal(t, 2, len(discussions.Discussions))
|
||||||
|
assert.False(t, gock.HasUnmatchedRequest())
|
||||||
|
assert.True(t, gock.IsDone())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestListDiscussions_NotFound(t *testing.T) {
|
||||||
|
repoInfo := repoInfo{
|
||||||
|
fullName: "test-author/doesnotexist",
|
||||||
|
resourceType: MODEL,
|
||||||
|
}
|
||||||
|
defer gock.Off()
|
||||||
|
|
||||||
|
gock.New("https://huggingface.co").
|
||||||
|
Get("/"+APIRoute+"/"+getResourceAPIPath(string(repoInfo.resourceType))+"/"+repoInfo.fullName+"/"+DiscussionsRoute).
|
||||||
|
MatchHeader("Authorization", "Bearer "+TEST_TOKEN).
|
||||||
|
Reply(404).
|
||||||
|
JSON(map[string]interface{}{
|
||||||
|
"discussions": []map[string]interface{}{},
|
||||||
|
})
|
||||||
|
|
||||||
|
client := initTestClient()
|
||||||
|
discussions, err := client.ListDiscussions(context.Background(), repoInfo)
|
||||||
|
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.NotNil(t, discussions)
|
||||||
|
assert.Equal(t, 0, len(discussions.Discussions))
|
||||||
|
assert.False(t, gock.HasUnmatchedRequest())
|
||||||
|
assert.True(t, gock.IsDone())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestListDiscussions_Error(t *testing.T) {
|
||||||
|
repoInfo := repoInfo{
|
||||||
|
fullName: "test-author/doesnotexist",
|
||||||
|
resourceType: MODEL,
|
||||||
|
}
|
||||||
|
defer gock.Off()
|
||||||
|
|
||||||
|
gock.New("https://huggingface.co").
|
||||||
|
Get("/"+APIRoute+"/"+getResourceAPIPath(string(repoInfo.resourceType))+"/"+repoInfo.fullName+"/"+DiscussionsRoute).
|
||||||
|
MatchHeader("Authorization", "Bearer "+TEST_TOKEN).
|
||||||
|
Reply(500)
|
||||||
|
|
||||||
|
client := initTestClient()
|
||||||
|
discussions, err := client.ListDiscussions(context.Background(), repoInfo)
|
||||||
|
|
||||||
|
assert.NotNil(t, err)
|
||||||
|
assert.NotNil(t, discussions)
|
||||||
|
assert.False(t, gock.HasUnmatchedRequest())
|
||||||
|
assert.True(t, gock.IsDone())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetDiscussionByID(t *testing.T) {
|
||||||
|
repoInfo := repoInfo{
|
||||||
|
fullName: "test-author/test-model",
|
||||||
|
resourceType: MODEL,
|
||||||
|
}
|
||||||
|
discussionID := "1"
|
||||||
|
|
||||||
|
jsonBlob := `{
|
||||||
|
"author": {
|
||||||
|
"avatarUrl": "/avatars/test.svg",
|
||||||
|
"fullname": "TEST",
|
||||||
|
"name": "test-author",
|
||||||
|
"type": "user",
|
||||||
|
"isPro": false,
|
||||||
|
"isHf": false,
|
||||||
|
"isMod": false
|
||||||
|
},
|
||||||
|
"num": 1,
|
||||||
|
"repo": {
|
||||||
|
"name": "test-author/test-model",
|
||||||
|
"type": "model"
|
||||||
|
},
|
||||||
|
"title": "secret in initial",
|
||||||
|
"status": "open",
|
||||||
|
"createdAt": "2024-06-18T14:31:46.000Z",
|
||||||
|
"events": [
|
||||||
|
{
|
||||||
|
"id": "525",
|
||||||
|
"author": {
|
||||||
|
"avatarUrl": "/avatars/test.svg",
|
||||||
|
"fullname": "TEST",
|
||||||
|
"name": "test-author",
|
||||||
|
"type": "user",
|
||||||
|
"isPro": false,
|
||||||
|
"isHf": false,
|
||||||
|
"isMod": false,
|
||||||
|
"isOwner": true,
|
||||||
|
"isOrgMember": false
|
||||||
|
},
|
||||||
|
"createdAt": "2024-06-18T14:31:46.000Z",
|
||||||
|
"type": "comment",
|
||||||
|
"data": {
|
||||||
|
"edited": true,
|
||||||
|
"hidden": false,
|
||||||
|
"latest": {
|
||||||
|
"raw": "dd",
|
||||||
|
"html": "<p>dd</p>\n",
|
||||||
|
"updatedAt": "2024-06-18T14:33:32.066Z",
|
||||||
|
"author": {
|
||||||
|
"avatarUrl": "/avatars/test.svg",
|
||||||
|
"fullname": "TEST",
|
||||||
|
"name": "test-author",
|
||||||
|
"type": "user",
|
||||||
|
"isPro": false,
|
||||||
|
"isHf": false,
|
||||||
|
"isMod": false
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"numEdits": 1,
|
||||||
|
"editors": ["trufflej"],
|
||||||
|
"reactions": [],
|
||||||
|
"identifiedLanguage": {
|
||||||
|
"language": "en",
|
||||||
|
"probability": 0.40104949474334717
|
||||||
|
},
|
||||||
|
"isReport": false
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "526",
|
||||||
|
"author": {
|
||||||
|
"avatarUrl": "/avatars/test.svg",
|
||||||
|
"fullname": "TEST",
|
||||||
|
"name": "test-author",
|
||||||
|
"type": "user",
|
||||||
|
"isPro": false,
|
||||||
|
"isHf": false,
|
||||||
|
"isMod": false,
|
||||||
|
"isOwner": true,
|
||||||
|
"isOrgMember": false
|
||||||
|
},
|
||||||
|
"createdAt": "2024-06-18T14:32:40.000Z",
|
||||||
|
"type": "status-change",
|
||||||
|
"data": {
|
||||||
|
"status": "closed"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "527",
|
||||||
|
"author": {
|
||||||
|
"avatarUrl": "/avatars/test.svg",
|
||||||
|
"fullname": "TEST",
|
||||||
|
"name": "test-author",
|
||||||
|
"type": "user",
|
||||||
|
"isPro": false,
|
||||||
|
"isHf": false,
|
||||||
|
"isMod": false,
|
||||||
|
"isOwner": true,
|
||||||
|
"isOrgMember": false
|
||||||
|
},
|
||||||
|
"createdAt": "2024-06-18T14:33:27.000Z",
|
||||||
|
"type": "status-change",
|
||||||
|
"data": {
|
||||||
|
"status": "open"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"pinned": false,
|
||||||
|
"locked": false,
|
||||||
|
"isPullRequest": false,
|
||||||
|
"isReport": false
|
||||||
|
}`
|
||||||
|
|
||||||
|
defer gock.Off()
|
||||||
|
|
||||||
|
gock.New("https://huggingface.co").
|
||||||
|
Get("/"+APIRoute+"/"+getResourceAPIPath(string(repoInfo.resourceType))+"/"+repoInfo.fullName+"/"+DiscussionsRoute+"/"+discussionID).
|
||||||
|
MatchHeader("Authorization", "Bearer "+TEST_TOKEN).
|
||||||
|
Reply(200).
|
||||||
|
JSON(jsonBlob)
|
||||||
|
|
||||||
|
client := initTestClient()
|
||||||
|
discussion, err := client.GetDiscussionByID(context.Background(), repoInfo, discussionID)
|
||||||
|
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.NotNil(t, discussion)
|
||||||
|
assert.Equal(t, discussionID, strconv.Itoa(discussion.ID))
|
||||||
|
assert.Equal(t, 3, len(discussion.Events))
|
||||||
|
assert.Equal(t, false, discussion.IsPR)
|
||||||
|
assert.Equal(t, "secret in initial", discussion.Title)
|
||||||
|
assert.Equal(t, repoInfo.fullName, discussion.Repo.FullName)
|
||||||
|
assert.Equal(t, string(repoInfo.resourceType), discussion.Repo.ResourceType)
|
||||||
|
assert.False(t, gock.HasUnmatchedRequest())
|
||||||
|
assert.True(t, gock.IsDone())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetDiscussionByID_NotFound(t *testing.T) {
|
||||||
|
repoInfo := repoInfo{
|
||||||
|
fullName: "test-author/test-model",
|
||||||
|
resourceType: MODEL,
|
||||||
|
}
|
||||||
|
discussionID := "doesnotexist"
|
||||||
|
defer gock.Off()
|
||||||
|
|
||||||
|
gock.New("https://huggingface.co").
|
||||||
|
Get("/"+APIRoute+"/"+getResourceAPIPath(string(repoInfo.resourceType))+"/"+repoInfo.fullName+"/"+DiscussionsRoute+"/"+discussionID).
|
||||||
|
MatchHeader("Authorization", "Bearer "+TEST_TOKEN).
|
||||||
|
Reply(404).
|
||||||
|
JSON(map[string]interface{}{})
|
||||||
|
|
||||||
|
client := initTestClient()
|
||||||
|
discussion, err := client.GetDiscussionByID(context.Background(), repoInfo, discussionID)
|
||||||
|
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.NotNil(t, discussion)
|
||||||
|
assert.Equal(t, 0, len(discussion.Events))
|
||||||
|
assert.False(t, gock.HasUnmatchedRequest())
|
||||||
|
assert.True(t, gock.IsDone())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetDiscussionByID_Error(t *testing.T) {
|
||||||
|
repoInfo := repoInfo{
|
||||||
|
fullName: "test-author/test-model",
|
||||||
|
resourceType: MODEL,
|
||||||
|
}
|
||||||
|
discussionID := "doesnotexist"
|
||||||
|
defer gock.Off()
|
||||||
|
|
||||||
|
gock.New("https://huggingface.co").
|
||||||
|
Get("/"+APIRoute+"/"+getResourceAPIPath(string(repoInfo.resourceType))+"/"+repoInfo.fullName+"/"+DiscussionsRoute+"/"+discussionID).
|
||||||
|
MatchHeader("Authorization", "Bearer "+TEST_TOKEN).
|
||||||
|
Reply(500)
|
||||||
|
|
||||||
|
client := initTestClient()
|
||||||
|
discussion, err := client.GetDiscussionByID(context.Background(), repoInfo, discussionID)
|
||||||
|
|
||||||
|
assert.NotNil(t, err)
|
||||||
|
assert.NotNil(t, discussion)
|
||||||
|
assert.Equal(t, "", discussion.Title)
|
||||||
|
assert.Equal(t, 0, len(discussion.Events))
|
||||||
|
assert.False(t, gock.HasUnmatchedRequest())
|
||||||
|
assert.True(t, gock.IsDone())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestListReposByAuthor(t *testing.T) {
|
||||||
|
resourceType := MODEL
|
||||||
|
author := "test-author"
|
||||||
|
repo := "test-model"
|
||||||
|
repo2 := "test-model2"
|
||||||
|
|
||||||
|
defer gock.Off()
|
||||||
|
|
||||||
|
gock.New("https://huggingface.co").
|
||||||
|
Get(fmt.Sprintf("/%s/%s", APIRoute, getResourceAPIPath(resourceType))).
|
||||||
|
MatchParam("author", author).
|
||||||
|
MatchParam("limit", "1000").
|
||||||
|
MatchHeader("Authorization", "Bearer "+TEST_TOKEN).
|
||||||
|
Reply(200).
|
||||||
|
JSON([]map[string]interface{}{
|
||||||
|
{
|
||||||
|
"_id": "1",
|
||||||
|
"id": author + "/" + repo,
|
||||||
|
"modelId": author + "/" + repo,
|
||||||
|
"private": true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"_id": "2",
|
||||||
|
"id": author + "/" + repo2,
|
||||||
|
"modelId": author + "/" + repo2,
|
||||||
|
"private": false,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
for _, mock := range gock.Pending() {
|
||||||
|
fmt.Println(mock.Request().URLStruct.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
client := initTestClient()
|
||||||
|
repos, err := client.ListReposByAuthor(context.Background(), resourceType, author)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.NotNil(t, repos)
|
||||||
|
assert.Equal(t, 2, len(repos))
|
||||||
|
// count of repos with private flag
|
||||||
|
countOfPrivateRepos := 0
|
||||||
|
for _, repo := range repos {
|
||||||
|
if repo.IsPrivate {
|
||||||
|
countOfPrivateRepos++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert.Equal(t, 1, countOfPrivateRepos)
|
||||||
|
// there is no author field in JSON, so assert repo.Owner is empty
|
||||||
|
for _, repo := range repos {
|
||||||
|
assert.Equal(t, "", repo.Owner)
|
||||||
|
}
|
||||||
|
assert.False(t, gock.HasUnmatchedRequest())
|
||||||
|
assert.True(t, gock.IsDone())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestListReposByAuthor_NotFound(t *testing.T) {
|
||||||
|
resourceType := MODEL
|
||||||
|
author := "authordoesntexist"
|
||||||
|
|
||||||
|
defer gock.Off()
|
||||||
|
|
||||||
|
gock.New("https://huggingface.co").
|
||||||
|
Get(fmt.Sprintf("/%s/%s", APIRoute, getResourceAPIPath(resourceType))).
|
||||||
|
MatchParam("author", author).
|
||||||
|
MatchParam("limit", "1000").
|
||||||
|
MatchHeader("Authorization", "Bearer "+TEST_TOKEN).
|
||||||
|
Reply(404).
|
||||||
|
JSON([]map[string]interface{}{})
|
||||||
|
|
||||||
|
client := initTestClient()
|
||||||
|
repos, err := client.ListReposByAuthor(context.Background(), resourceType, author)
|
||||||
|
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.NotNil(t, repos)
|
||||||
|
assert.Equal(t, 0, len(repos))
|
||||||
|
assert.False(t, gock.HasUnmatchedRequest())
|
||||||
|
assert.True(t, gock.IsDone())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestListReposByAuthor_Error(t *testing.T) {
|
||||||
|
resourceType := MODEL
|
||||||
|
author := "doesnotexist"
|
||||||
|
|
||||||
|
defer gock.Off()
|
||||||
|
|
||||||
|
gock.New("https://huggingface.co").
|
||||||
|
Get(fmt.Sprintf("/%s/%s", APIRoute, getResourceAPIPath(resourceType))).
|
||||||
|
MatchParam("author", author).
|
||||||
|
MatchParam("limit", "1000").
|
||||||
|
MatchHeader("Authorization", "Bearer "+TEST_TOKEN).
|
||||||
|
Reply(500)
|
||||||
|
|
||||||
|
client := initTestClient()
|
||||||
|
repos, err := client.ListReposByAuthor(context.Background(), resourceType, author)
|
||||||
|
assert.NotNil(t, err)
|
||||||
|
assert.Nil(t, repos)
|
||||||
|
assert.False(t, gock.HasUnmatchedRequest())
|
||||||
|
assert.True(t, gock.IsDone())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetResourceAPIPath(t *testing.T) {
	assert.Equal(t, "models", getResourceAPIPath(MODEL))
	assert.Equal(t, "datasets", getResourceAPIPath(DATASET))
	assert.Equal(t, "spaces", getResourceAPIPath(SPACE))
}

func TestGetResourceHTMLPath(t *testing.T) {
	assert.Equal(t, "", getResourceHTMLPath(MODEL))
	assert.Equal(t, "datasets", getResourceHTMLPath(DATASET))
	assert.Equal(t, "spaces", getResourceHTMLPath(SPACE))
}

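The two tests above pin down the asymmetry between API paths and website/git paths: every resource type has an API collection, but only datasets and spaces carry a URL prefix on the site. A sketch of helpers consistent with those expectations; the constant values and function names here are assumptions, not the package's real definitions:

```go
package main

import "fmt"

// Hypothetical stand-ins for the package's resource-type constants.
const (
	model   = "model"
	dataset = "dataset"
	space   = "space"
)

// apiPathSketch mirrors what getResourceAPIPath is expected to return:
// every resource type maps to an API collection name.
func apiPathSketch(resourceType string) string {
	switch resourceType {
	case model:
		return "models"
	case dataset:
		return "datasets"
	case space:
		return "spaces"
	}
	return ""
}

// htmlPathSketch mirrors getResourceHTMLPath: models sit at the site root,
// while datasets and spaces keep their collection prefix.
func htmlPathSketch(resourceType string) string {
	if resourceType == model {
		return ""
	}
	return apiPathSketch(resourceType)
}

func main() {
	fmt.Println(apiPathSketch(model), htmlPathSketch(dataset), htmlPathSketch(space)) // models datasets spaces
}
```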
func TestBuildAPIURL_ValidInputs(t *testing.T) {
|
||||||
|
endpoint := "https://huggingface.co"
|
||||||
|
resourceType := MODEL
|
||||||
|
repoName := "test-repo"
|
||||||
|
|
||||||
|
expectedURL := "https://huggingface.co/api/models/test-repo"
|
||||||
|
|
||||||
|
url, err := buildAPIURL(endpoint, resourceType, repoName)
|
||||||
|
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, expectedURL, url)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildAPIURL_EmptyEndpoint(t *testing.T) {
|
||||||
|
endpoint := ""
|
||||||
|
resourceType := MODEL
|
||||||
|
repoName := "test-repo"
|
||||||
|
|
||||||
|
url, err := buildAPIURL(endpoint, resourceType, repoName)
|
||||||
|
|
||||||
|
assert.NotNil(t, err)
|
||||||
|
assert.Equal(t, "", url)
|
||||||
|
assert.Equal(t, "endpoint, resourceType, and repoName must not be empty", err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildAPIURL_EmptyResourceType(t *testing.T) {
|
||||||
|
endpoint := "https://huggingface.co"
|
||||||
|
resourceType := ""
|
||||||
|
repoName := "test-repo"
|
||||||
|
|
||||||
|
url, err := buildAPIURL(endpoint, resourceType, repoName)
|
||||||
|
|
||||||
|
assert.NotNil(t, err)
|
||||||
|
assert.Equal(t, "", url)
|
||||||
|
assert.Equal(t, "endpoint, resourceType, and repoName must not be empty", err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildAPIURL_EmptyRepoName(t *testing.T) {
|
||||||
|
endpoint := "https://huggingface.co"
|
||||||
|
resourceType := "model"
|
||||||
|
repoName := ""
|
||||||
|
|
||||||
|
url, err := buildAPIURL(endpoint, resourceType, repoName)
|
||||||
|
|
||||||
|
assert.NotNil(t, err)
|
||||||
|
assert.Equal(t, "", url)
|
||||||
|
assert.Equal(t, "endpoint, resourceType, and repoName must not be empty", err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetDiscussionPath_ModelResource(t *testing.T) {
|
||||||
|
discussion := Discussion{
|
||||||
|
Repo: RepoData{
|
||||||
|
FullName: "test-author/test-model",
|
||||||
|
ResourceType: "model",
|
||||||
|
},
|
||||||
|
ID: 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedPath := "test-author/test-model/discussions/1"
|
||||||
|
|
||||||
|
path := discussion.GetDiscussionPath()
|
||||||
|
assert.Equal(t, expectedPath, path)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetDiscussionPath_DatasetResource(t *testing.T) {
|
||||||
|
discussion := Discussion{
|
||||||
|
Repo: RepoData{
|
||||||
|
FullName: "test-author/test-dataset",
|
||||||
|
ResourceType: "dataset",
|
||||||
|
},
|
||||||
|
ID: 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedPath := "datasets/test-author/test-dataset/discussions/1"
|
||||||
|
|
||||||
|
path := discussion.GetDiscussionPath()
|
||||||
|
assert.Equal(t, expectedPath, path)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetDiscussionPath_SpaceResource(t *testing.T) {
|
||||||
|
discussion := Discussion{
|
||||||
|
Repo: RepoData{
|
||||||
|
FullName: "test-author/test-space",
|
||||||
|
ResourceType: "space",
|
||||||
|
},
|
||||||
|
ID: 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedPath := "spaces/test-author/test-space/discussions/1"
|
||||||
|
|
||||||
|
path := discussion.GetDiscussionPath()
|
||||||
|
assert.Equal(t, expectedPath, path)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetGitPath_ModelResource(t *testing.T) {
|
||||||
|
discussion := Discussion{
|
||||||
|
Repo: RepoData{
|
||||||
|
FullName: "test-author/test-model",
|
||||||
|
ResourceType: "model",
|
||||||
|
},
|
||||||
|
ID: 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedPath := "test-author/test-model.git"
|
||||||
|
|
||||||
|
path := discussion.GetGitPath()
|
||||||
|
assert.Equal(t, expectedPath, path)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetGitPath_DatasetResource(t *testing.T) {
|
||||||
|
discussion := Discussion{
|
||||||
|
Repo: RepoData{
|
||||||
|
FullName: "test-author/test-dataset",
|
||||||
|
ResourceType: "dataset",
|
||||||
|
},
|
||||||
|
ID: 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedPath := "datasets/test-author/test-dataset.git"
|
||||||
|
|
||||||
|
path := discussion.GetGitPath()
|
||||||
|
assert.Equal(t, expectedPath, path)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetGitPath_SpaceResource(t *testing.T) {
|
||||||
|
discussion := Discussion{
|
||||||
|
Repo: RepoData{
|
||||||
|
FullName: "test-author/test-space",
|
||||||
|
ResourceType: "space",
|
||||||
|
},
|
||||||
|
ID: 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedPath := "spaces/test-author/test-space.git"
|
||||||
|
|
||||||
|
path := discussion.GetGitPath()
|
||||||
|
assert.Equal(t, expectedPath, path)
|
||||||
|
}
|
986 pkg/sources/huggingface/huggingface_test.go (new file)

@@ -0,0 +1,986 @@
package huggingface

import (
	"fmt"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/trufflesecurity/trufflehog/v3/pkg/context"
	"github.com/trufflesecurity/trufflehog/v3/pkg/pb/source_metadatapb"
	"github.com/trufflesecurity/trufflehog/v3/pkg/pb/sourcespb"
	"google.golang.org/protobuf/types/known/anypb"
	"gopkg.in/h2non/gock.v1"
)

func createTestSource(src *sourcespb.Huggingface) (*Source, *anypb.Any) {
	s := &Source{}
	conn, err := anypb.New(src)
	if err != nil {
		panic(err)
	}
	return s, conn
}

// test include exclude ignore/include orgs, users

func TestInit(t *testing.T) {
|
||||||
|
s, conn := createTestSource(&sourcespb.Huggingface{
|
||||||
|
Endpoint: "https://huggingface.co",
|
||||||
|
Credential: &sourcespb.Huggingface_Token{
|
||||||
|
Token: "super secret token",
|
||||||
|
},
|
||||||
|
Models: []string{"user/model1", "user/model2", "user/ignorethismodel"},
|
||||||
|
IgnoreModels: []string{"user/ignorethismodel"},
|
||||||
|
Spaces: []string{"user/space1", "user/space2", "user/ignorethisspace"},
|
||||||
|
IgnoreSpaces: []string{"user/ignorethisspace"},
|
||||||
|
Datasets: []string{"user/dataset1", "user/dataset2", "user/ignorethisdataset"},
|
||||||
|
IgnoreDatasets: []string{"user/ignorethisdataset"},
|
||||||
|
Organizations: []string{"org1", "org2"},
|
||||||
|
Users: []string{"user1", "user2"},
|
||||||
|
})
|
||||||
|
err := s.Init(context.Background(), "test - huggingface", 0, 1337, false, conn, 1)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
assert.ElementsMatch(t, []string{"user/model1", "user/model2"}, s.models)
|
||||||
|
for _, model := range s.models {
|
||||||
|
modelURL, _ := s.filteredModelsCache.Get(model)
|
||||||
|
assert.Equal(t, modelURL, s.conn.Endpoint+"/"+model+".git")
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.ElementsMatch(t, []string{"user/space1", "user/space2"}, s.spaces)
|
||||||
|
for _, space := range s.spaces {
|
||||||
|
spaceURL, _ := s.filteredSpacesCache.Get(space)
|
||||||
|
assert.Equal(t, spaceURL, s.conn.Endpoint+"/"+getResourceHTMLPath(SPACE)+"/"+space+".git")
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.ElementsMatch(t, []string{"user/dataset1", "user/dataset2"}, s.datasets)
|
||||||
|
for _, dataset := range s.datasets {
|
||||||
|
datasetURL, _ := s.filteredDatasetsCache.Get(dataset)
|
||||||
|
assert.Equal(t, datasetURL, s.conn.Endpoint+"/"+getResourceHTMLPath(DATASET)+"/"+dataset+".git")
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.ElementsMatch(t, s.conn.Organizations, s.orgsCache.Keys())
|
||||||
|
assert.ElementsMatch(t, s.conn.Users, s.usersCache.Keys())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetResourceType(t *testing.T) {
|
||||||
|
repo := "author/model1"
|
||||||
|
s, conn := createTestSource(&sourcespb.Huggingface{
|
||||||
|
Endpoint: "https://huggingface.co",
|
||||||
|
Credential: &sourcespb.Huggingface_Token{
|
||||||
|
Token: "super secret token",
|
||||||
|
},
|
||||||
|
Models: []string{repo},
|
||||||
|
})
|
||||||
|
err := s.Init(context.Background(), "test - huggingface", 0, 1337, false, conn, 1)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
defer gock.Off()
|
||||||
|
|
||||||
|
// mock the request to the huggingface api
|
||||||
|
gock.New("https://huggingface.co").
|
||||||
|
Get("/api/models/author/model1").
|
||||||
|
Reply(200).
|
||||||
|
JSON(map[string]interface{}{
|
||||||
|
"id": "author/model1",
|
||||||
|
"author": "author",
|
||||||
|
"private": true,
|
||||||
|
})
|
||||||
|
|
||||||
|
err = s.enumerate(context.Background())
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, MODEL, s.getResourceType(context.Background(), (s.conn.Endpoint+"/"+repo+".git")))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestVisibilityOf(t *testing.T) {
|
||||||
|
repo := "author/model1"
|
||||||
|
s, conn := createTestSource(&sourcespb.Huggingface{
|
||||||
|
Endpoint: "https://huggingface.co",
|
||||||
|
Credential: &sourcespb.Huggingface_Token{
|
||||||
|
Token: "super secret token",
|
||||||
|
},
|
||||||
|
Models: []string{repo},
|
||||||
|
})
|
||||||
|
err := s.Init(context.Background(), "test - huggingface", 0, 1337, false, conn, 1)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
defer gock.Off()
|
||||||
|
|
||||||
|
// mock the request to the huggingface api
|
||||||
|
gock.New("https://huggingface.co").
|
||||||
|
Get("/api/models/author/model1").
|
||||||
|
Reply(200).
|
||||||
|
JSON(map[string]interface{}{
|
||||||
|
"id": "author/model1",
|
||||||
|
"author": "author",
|
||||||
|
"private": true,
|
||||||
|
})
|
||||||
|
|
||||||
|
err = s.enumerate(context.Background())
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, source_metadatapb.Visibility(1), s.visibilityOf(context.Background(), (s.conn.Endpoint+"/"+repo+".git")))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEnumerate(t *testing.T) {
|
||||||
|
s, conn := createTestSource(&sourcespb.Huggingface{
|
||||||
|
Endpoint: "https://huggingface.co",
|
||||||
|
Credential: &sourcespb.Huggingface_Token{
|
||||||
|
Token: "super secret token",
|
||||||
|
},
|
||||||
|
Models: []string{"author/model1"},
|
||||||
|
Datasets: []string{"author/dataset1"},
|
||||||
|
Spaces: []string{"author/space1"},
|
||||||
|
})
|
||||||
|
|
||||||
|
err := s.Init(context.Background(), "test - huggingface", 0, 1337, false, conn, 1)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
defer gock.Off()
|
||||||
|
|
||||||
|
// mock the request to the huggingface api
|
||||||
|
gock.New("https://huggingface.co").
|
||||||
|
Get("/api/models/author/model1").
|
||||||
|
Reply(200).
|
||||||
|
JSON(map[string]interface{}{
|
||||||
|
"id": "author/model1",
|
||||||
|
"author": "author",
|
||||||
|
"private": true,
|
||||||
|
})
|
||||||
|
|
||||||
|
gock.New("https://huggingface.co").
|
||||||
|
Get("/api/datasets/author/dataset1").
|
||||||
|
Reply(200).
|
||||||
|
JSON(map[string]interface{}{
|
||||||
|
"id": "author/dataset1",
|
||||||
|
"author": "author",
|
||||||
|
"private": false,
|
||||||
|
})
|
||||||
|
|
||||||
|
gock.New("https://huggingface.co").
|
||||||
|
Get("/api/spaces/author/space1").
|
||||||
|
Reply(200).
|
||||||
|
JSON(map[string]interface{}{
|
||||||
|
"id": "author/space1",
|
||||||
|
"author": "author",
|
||||||
|
"private": false,
|
||||||
|
})
|
||||||
|
|
||||||
|
err = s.enumerate(context.Background())
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
modelGitURL := "https://huggingface.co/author/model1.git"
|
||||||
|
datasetGitURL := "https://huggingface.co/datasets/author/dataset1.git"
|
||||||
|
spaceGitURL := "https://huggingface.co/spaces/author/space1.git"
|
||||||
|
|
||||||
|
assert.Equal(t, []string{modelGitURL}, s.models)
|
||||||
|
assert.Equal(t, []string{datasetGitURL}, s.datasets)
|
||||||
|
assert.Equal(t, []string{spaceGitURL}, s.spaces)
|
||||||
|
|
||||||
|
r, _ := s.repoInfoCache.get(modelGitURL)
|
||||||
|
assert.Equal(t, r, repoInfo{
|
||||||
|
visibility: source_metadatapb.Visibility_private,
|
||||||
|
resourceType: MODEL,
|
||||||
|
owner: "author",
|
||||||
|
name: "model1",
|
||||||
|
fullName: "author/model1",
|
||||||
|
})
|
||||||
|
|
||||||
|
r, _ = s.repoInfoCache.get(datasetGitURL)
|
||||||
|
assert.Equal(t, r, repoInfo{
|
||||||
|
visibility: source_metadatapb.Visibility_public,
|
||||||
|
resourceType: DATASET,
|
||||||
|
owner: "author",
|
||||||
|
name: "dataset1",
|
||||||
|
fullName: "author/dataset1",
|
||||||
|
})
|
||||||
|
|
||||||
|
r, _ = s.repoInfoCache.get(spaceGitURL)
|
||||||
|
assert.Equal(t, r, repoInfo{
|
||||||
|
visibility: source_metadatapb.Visibility_public,
|
||||||
|
resourceType: SPACE,
|
||||||
|
owner: "author",
|
||||||
|
name: "space1",
|
||||||
|
fullName: "author/space1",
|
||||||
|
})
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateRepoList(t *testing.T) {
|
||||||
|
s, conn := createTestSource(&sourcespb.Huggingface{
|
||||||
|
Endpoint: "https://huggingface.co",
|
||||||
|
Credential: &sourcespb.Huggingface_Token{
|
||||||
|
Token: "super secret token",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
err := s.Init(context.Background(), "test - huggingface", 0, 1337, false, conn, 1)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
modelGitURL := "https://huggingface.co/author/model1.git"
|
||||||
|
datasetGitURL := "https://huggingface.co/datasets/author/dataset1.git"
|
||||||
|
spaceGitURL := "https://huggingface.co/spaces/author/space1.git"
|
||||||
|
|
||||||
|
s.updateRepoLists(modelGitURL, MODEL)
|
||||||
|
s.updateRepoLists(datasetGitURL, DATASET)
|
||||||
|
s.updateRepoLists(spaceGitURL, SPACE)
|
||||||
|
|
||||||
|
assert.Equal(t, []string{modelGitURL}, s.models)
|
||||||
|
assert.Equal(t, []string{datasetGitURL}, s.datasets)
|
||||||
|
assert.Equal(t, []string{spaceGitURL}, s.spaces)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetReposListByType(t *testing.T) {
|
||||||
|
s, conn := createTestSource(&sourcespb.Huggingface{
|
||||||
|
Endpoint: "https://huggingface.co",
|
||||||
|
Credential: &sourcespb.Huggingface_Token{
|
||||||
|
Token: "super secret token",
|
||||||
|
},
|
||||||
|
Models: []string{"author/model1", "author/model2"},
|
||||||
|
Datasets: []string{"author/dataset1", "author/dataset2"},
|
||||||
|
Spaces: []string{"author/space1", "author/space2"},
|
||||||
|
})
|
||||||
|
|
||||||
|
err := s.Init(context.Background(), "test - huggingface", 0, 1337, false, conn, 1)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, s.getReposListByType(MODEL), s.models)
|
||||||
|
assert.Equal(t, s.getReposListByType(DATASET), s.datasets)
|
||||||
|
assert.Equal(t, s.getReposListByType(SPACE), s.spaces)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetVisibility(t *testing.T) {
	assert.Equal(t, source_metadatapb.Visibility(1), getVisibility(true))
	assert.Equal(t, source_metadatapb.Visibility(0), getVisibility(false))
}

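TestGetVisibility expects the boolean private flag to map onto the source_metadatapb.Visibility enum, with 1 for private and 0 for public. A minimal sketch of that mapping using a local stand-in type rather than the generated proto enum:

```go
package main

import "fmt"

// Visibility mirrors the two enum values asserted in TestGetVisibility
// (0 = public, 1 = private); the real type lives in source_metadatapb.
type Visibility int32

const (
	visibilityPublic  Visibility = 0
	visibilityPrivate Visibility = 1
)

// getVisibilitySketch maps HuggingFace's "private" boolean onto the enum.
func getVisibilitySketch(isPrivate bool) Visibility {
	if isPrivate {
		return visibilityPrivate
	}
	return visibilityPublic
}

func main() {
	fmt.Println(getVisibilitySketch(true), getVisibilitySketch(false)) // 1 0
}
```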
func TestEnumerateAuthorsOrg(t *testing.T) {
|
||||||
|
s, conn := createTestSource(&sourcespb.Huggingface{
|
||||||
|
Endpoint: "https://huggingface.co",
|
||||||
|
Credential: &sourcespb.Huggingface_Token{
|
||||||
|
Token: "super secret token",
|
||||||
|
},
|
||||||
|
Organizations: []string{"org"},
|
||||||
|
})
|
||||||
|
|
||||||
|
err := s.Init(context.Background(), "test - huggingface", 0, 1337, false, conn, 1)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
defer gock.Off()
|
||||||
|
|
||||||
|
// mock the request to the huggingface api
|
||||||
|
gock.New("https://huggingface.co").
|
||||||
|
Get(fmt.Sprintf("/%s/%s", APIRoute, getResourceAPIPath(MODEL))).
|
||||||
|
MatchParam("author", "org").
|
||||||
|
MatchParam("limit", "1000").
|
||||||
|
Reply(200).
|
||||||
|
JSON([]map[string]interface{}{
|
||||||
|
{
|
||||||
|
"_id": "1",
|
||||||
|
"id": "org/model",
|
||||||
|
"modelId": "org/model",
|
||||||
|
"private": true,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
gock.New("https://huggingface.co").
|
||||||
|
Get("/api/models/org/model").
|
||||||
|
Reply(200).
|
||||||
|
JSON(map[string]interface{}{
|
||||||
|
"id": "org/model",
|
||||||
|
"author": "org",
|
||||||
|
"private": true,
|
||||||
|
})
|
||||||
|
|
||||||
|
gock.New("https://huggingface.co").
|
||||||
|
Get(fmt.Sprintf("/%s/%s", APIRoute, getResourceAPIPath(DATASET))).
|
||||||
|
MatchParam("author", "org").
|
||||||
|
MatchParam("limit", "1000").
|
||||||
|
Reply(200).
|
||||||
|
JSON([]map[string]interface{}{
|
||||||
|
{
|
||||||
|
"_id": "3",
|
||||||
|
"id": "org/dataset",
|
||||||
|
"modelId": "org/dataset",
|
||||||
|
"private": true,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
gock.New("https://huggingface.co").
|
||||||
|
Get("/api/datasets/org/dataset").
|
||||||
|
Reply(200).
|
||||||
|
JSON(map[string]interface{}{
|
||||||
|
"id": "org/dataset",
|
||||||
|
"author": "org",
|
||||||
|
"private": true,
|
||||||
|
})
|
||||||
|
|
||||||
|
gock.New("https://huggingface.co").
|
||||||
|
Get(fmt.Sprintf("/%s/%s", APIRoute, getResourceAPIPath(SPACE))).
|
||||||
|
MatchParam("author", "org").
|
||||||
|
MatchParam("limit", "1000").
|
||||||
|
Reply(200).
|
||||||
|
JSON([]map[string]interface{}{
|
||||||
|
{
|
||||||
|
"_id": "5",
|
||||||
|
"id": "org/space",
|
||||||
|
"modelId": "org/space",
|
||||||
|
"private": true,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
gock.New("https://huggingface.co").
|
||||||
|
Get("/api/spaces/org/space").
|
||||||
|
Reply(200).
|
||||||
|
JSON(map[string]interface{}{
|
||||||
|
"id": "org/space",
|
||||||
|
"author": "org",
|
||||||
|
"private": true,
|
||||||
|
})
|
||||||
|
|
||||||
|
err = s.enumerate(context.Background())
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
modelGitURL := "https://huggingface.co/org/model.git"
|
||||||
|
datasetGitURL := "https://huggingface.co/datasets/org/dataset.git"
|
||||||
|
spaceGitURL := "https://huggingface.co/spaces/org/space.git"
|
||||||
|
|
||||||
|
assert.Equal(t, []string{modelGitURL}, s.models)
|
||||||
|
assert.Equal(t, []string{datasetGitURL}, s.datasets)
|
||||||
|
assert.Equal(t, []string{spaceGitURL}, s.spaces)
|
||||||
|
|
||||||
|
r, _ := s.repoInfoCache.get(modelGitURL)
|
||||||
|
assert.Equal(t, r, repoInfo{
|
||||||
|
visibility: source_metadatapb.Visibility_private,
|
||||||
|
resourceType: MODEL,
|
||||||
|
owner: "org",
|
||||||
|
name: "model",
|
||||||
|
fullName: "org/model",
|
||||||
|
})
|
||||||
|
|
||||||
|
r, _ = s.repoInfoCache.get(datasetGitURL)
|
||||||
|
assert.Equal(t, r, repoInfo{
|
||||||
|
visibility: source_metadatapb.Visibility_private,
|
||||||
|
resourceType: DATASET,
|
||||||
|
owner: "org",
|
||||||
|
name: "dataset",
|
||||||
|
fullName: "org/dataset",
|
||||||
|
})
|
||||||
|
|
||||||
|
r, _ = s.repoInfoCache.get(spaceGitURL)
|
||||||
|
assert.Equal(t, r, repoInfo{
|
||||||
|
visibility: source_metadatapb.Visibility_private,
|
||||||
|
resourceType: SPACE,
|
||||||
|
owner: "org",
|
||||||
|
name: "space",
|
||||||
|
fullName: "org/space",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEnumerateAuthorsOrgSkipAll(t *testing.T) {
|
||||||
|
s, conn := createTestSource(&sourcespb.Huggingface{
|
||||||
|
Endpoint: "https://huggingface.co",
|
||||||
|
Credential: &sourcespb.Huggingface_Token{
|
||||||
|
Token: "super secret token",
|
||||||
|
},
|
||||||
|
Organizations: []string{"org"},
|
||||||
|
SkipAllModels: true,
|
||||||
|
SkipAllDatasets: true,
|
||||||
|
SkipAllSpaces: true,
|
||||||
|
})
|
||||||
|
|
||||||
|
err := s.Init(context.Background(), "test - huggingface", 0, 1337, false, conn, 1)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
defer gock.Off()
|
||||||
|
|
||||||
|
// mock the request to the huggingface api
|
||||||
|
gock.New("https://huggingface.co").
|
||||||
|
Get(fmt.Sprintf("/%s/%s", APIRoute, getResourceAPIPath(MODEL))).
|
||||||
|
MatchParam("author", "org").
|
||||||
|
MatchParam("limit", "1000").
|
||||||
|
Reply(200).
|
||||||
|
JSON([]map[string]interface{}{
|
||||||
|
{
|
||||||
|
"_id": "1",
|
||||||
|
"id": "org/model",
|
||||||
|
"modelId": "org/model",
|
||||||
|
"private": true,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
gock.New("https://huggingface.co").
|
||||||
|
Get("/api/models/org/model").
|
||||||
|
Reply(200).
|
||||||
|
JSON(map[string]interface{}{
|
||||||
|
"id": "org/model",
|
||||||
|
"author": "org",
|
||||||
|
"private": true,
|
||||||
|
})
|
||||||
|
|
||||||
|
gock.New("https://huggingface.co").
|
||||||
|
Get(fmt.Sprintf("/%s/%s", APIRoute, getResourceAPIPath(DATASET))).
|
||||||
|
MatchParam("author", "org").
|
||||||
|
MatchParam("limit", "1000").
|
||||||
|
Reply(200).
|
||||||
|
JSON([]map[string]interface{}{
|
||||||
|
{
|
||||||
|
"_id": "3",
|
||||||
|
"id": "org/dataset",
|
||||||
|
"modelId": "org/dataset",
|
||||||
|
"private": true,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
gock.New("https://huggingface.co").
|
||||||
|
Get("/api/datasets/org/dataset").
|
||||||
|
Reply(200).
|
||||||
|
JSON(map[string]interface{}{
|
||||||
|
"id": "org/dataset",
|
||||||
|
"author": "org",
|
||||||
|
"private": true,
|
||||||
|
})
|
||||||
|
|
||||||
|
gock.New("https://huggingface.co").
|
||||||
|
Get(fmt.Sprintf("/%s/%s", APIRoute, getResourceAPIPath(SPACE))).
|
||||||
|
MatchParam("author", "org").
|
||||||
|
MatchParam("limit", "1000").
|
||||||
|
Reply(200).
|
||||||
|
JSON([]map[string]interface{}{
|
||||||
|
{
|
||||||
|
"_id": "5",
|
||||||
|
"id": "org/space",
|
||||||
|
"modelId": "org/space",
|
||||||
|
"private": true,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
gock.New("https://huggingface.co").
|
||||||
|
Get("/api/spaces/org/space").
|
||||||
|
Reply(200).
|
||||||
|
JSON(map[string]interface{}{
|
||||||
|
"id": "org/space",
|
||||||
|
"author": "org",
|
||||||
|
"private": true,
|
||||||
|
})
|
||||||
|
|
||||||
|
err = s.enumerate(context.Background())
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
modelGitURL := "https://huggingface.co/org/model.git"
|
||||||
|
datasetGitURL := "https://huggingface.co/datasets/org/dataset.git"
|
||||||
|
spaceGitURL := "https://huggingface.co/spaces/org/space.git"
|
||||||
|
|
||||||
|
assert.Equal(t, []string{}, s.models)
|
||||||
|
assert.Equal(t, []string{}, s.datasets)
|
||||||
|
assert.Equal(t, []string{}, s.spaces)
|
||||||
|
|
||||||
|
r, _ := s.repoInfoCache.get(modelGitURL)
|
||||||
|
assert.Equal(t, r, repoInfo{})
|
||||||
|
|
||||||
|
r, _ = s.repoInfoCache.get(datasetGitURL)
|
||||||
|
assert.Equal(t, r, repoInfo{})
|
||||||
|
|
||||||
|
r, _ = s.repoInfoCache.get(spaceGitURL)
|
||||||
|
assert.Equal(t, r, repoInfo{})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEnumerateAuthorsOrgIgnores(t *testing.T) {
|
||||||
|
s, conn := createTestSource(&sourcespb.Huggingface{
|
||||||
|
Endpoint: "https://huggingface.co",
|
||||||
|
Credential: &sourcespb.Huggingface_Token{
|
||||||
|
Token: "super secret token",
|
||||||
|
},
|
||||||
|
Organizations: []string{"org"},
|
||||||
|
IgnoreModels: []string{"org/model"},
|
||||||
|
IgnoreDatasets: []string{"org/dataset"},
|
||||||
|
IgnoreSpaces: []string{"org/space"},
|
||||||
|
})
|
||||||
|
|
||||||
|
err := s.Init(context.Background(), "test - huggingface", 0, 1337, false, conn, 1)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
defer gock.Off()
|
||||||
|
|
||||||
|
// mock the request to the huggingface api
|
||||||
|
gock.New("https://huggingface.co").
|
||||||
|
Get(fmt.Sprintf("/%s/%s", APIRoute, getResourceAPIPath(MODEL))).
|
||||||
|
MatchParam("author", "org").
|
||||||
|
MatchParam("limit", "1000").
|
||||||
|
Reply(200).
|
||||||
|
JSON([]map[string]interface{}{
|
||||||
|
{
|
||||||
|
"_id": "1",
|
||||||
|
"id": "org/model",
|
||||||
|
"modelId": "org/model",
|
||||||
|
"private": true,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
gock.New("https://huggingface.co").
|
||||||
|
Get("/api/models/org/model").
|
||||||
|
Reply(200).
|
||||||
|
JSON(map[string]interface{}{
|
||||||
|
"id": "org/model",
|
||||||
|
"author": "org",
|
||||||
|
"private": true,
|
||||||
|
})
|
||||||
|
|
||||||
|
gock.New("https://huggingface.co").
|
||||||
|
Get(fmt.Sprintf("/%s/%s", APIRoute, getResourceAPIPath(DATASET))).
|
||||||
|
MatchParam("author", "org").
|
||||||
|
MatchParam("limit", "1000").
|
||||||
|
Reply(200).
|
||||||
|
JSON([]map[string]interface{}{
|
||||||
|
{
|
||||||
|
"_id": "3",
|
||||||
|
"id": "org/dataset",
|
||||||
|
"modelId": "org/dataset",
|
||||||
|
"private": true,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
gock.New("https://huggingface.co").
|
||||||
|
Get("/api/datasets/org/dataset").
|
||||||
|
Reply(200).
|
||||||
|
JSON(map[string]interface{}{
|
||||||
|
"id": "org/dataset",
|
||||||
|
"author": "org",
|
||||||
|
"private": true,
|
||||||
|
})
|
||||||
|
|
||||||
|
gock.New("https://huggingface.co").
|
||||||
|
Get(fmt.Sprintf("/%s/%s", APIRoute, getResourceAPIPath(SPACE))).
|
||||||
|
MatchParam("author", "org").
|
||||||
|
MatchParam("limit", "1000").
|
||||||
|
Reply(200).
|
||||||
|
JSON([]map[string]interface{}{
|
||||||
|
{
|
||||||
|
"_id": "5",
|
||||||
|
"id": "org/space",
|
||||||
|
"modelId": "org/space",
|
||||||
|
"private": true,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
gock.New("https://huggingface.co").
|
||||||
|
Get("/api/spaces/org/space").
|
||||||
|
Reply(200).
|
||||||
|
JSON(map[string]interface{}{
|
||||||
|
"id": "org/space",
|
||||||
|
"author": "org",
|
||||||
|
"private": true,
|
||||||
|
})
|
||||||
|
|
||||||
|
err = s.enumerate(context.Background())
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
modelGitURL := "https://huggingface.co/org/model.git"
|
||||||
|
datasetGitURL := "https://huggingface.co/datasets/org/dataset.git"
|
||||||
|
spaceGitURL := "https://huggingface.co/spaces/org/space.git"
|
||||||
|
|
||||||
|
assert.Equal(t, []string{}, s.models)
|
||||||
|
assert.Equal(t, []string{}, s.datasets)
|
||||||
|
assert.Equal(t, []string{}, s.spaces)
|
||||||
|
|
||||||
|
r, _ := s.repoInfoCache.get(modelGitURL)
|
||||||
|
assert.Equal(t, r, repoInfo{})
|
||||||
|
|
||||||
|
r, _ = s.repoInfoCache.get(datasetGitURL)
|
||||||
|
assert.Equal(t, r, repoInfo{})
|
||||||
|
|
||||||
|
r, _ = s.repoInfoCache.get(spaceGitURL)
|
||||||
|
assert.Equal(t, r, repoInfo{})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEnumerateAuthorsUser(t *testing.T) {
|
||||||
|
s, conn := createTestSource(&sourcespb.Huggingface{
|
||||||
|
Endpoint: "https://huggingface.co",
|
||||||
|
Credential: &sourcespb.Huggingface_Token{
|
||||||
|
Token: "super secret token",
|
||||||
|
},
|
||||||
|
Users: []string{"user"},
|
||||||
|
})
|
||||||
|
|
||||||
|
err := s.Init(context.Background(), "test - huggingface", 0, 1337, false, conn, 1)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
defer gock.Off()
|
||||||
|
|
||||||
|
gock.New("https://huggingface.co").
|
||||||
|
Get(fmt.Sprintf("/%s/%s", APIRoute, getResourceAPIPath(MODEL))).
|
||||||
|
MatchParam("author", "user").
|
||||||
|
MatchParam("limit", "1000").
|
||||||
|
Reply(200).
|
||||||
|
JSON([]map[string]interface{}{
|
||||||
|
{
|
||||||
|
"_id": "2",
|
||||||
|
"id": "user/model",
|
||||||
|
"modelId": "user/model",
|
||||||
|
"private": false,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
gock.New("https://huggingface.co").
|
||||||
|
Get("/api/models/user/model").
|
||||||
|
Reply(200).
|
||||||
|
JSON(map[string]interface{}{
|
||||||
|
"id": "user/model",
|
||||||
|
"author": "user",
|
||||||
|
"private": false,
|
||||||
|
})
|
||||||
|
|
||||||
|
gock.New("https://huggingface.co").
|
||||||
|
Get(fmt.Sprintf("/%s/%s", APIRoute, getResourceAPIPath(DATASET))).
|
||||||
|
MatchParam("author", "user").
|
||||||
|
MatchParam("limit", "1000").
|
||||||
|
Reply(200).
|
||||||
|
JSON([]map[string]interface{}{
|
||||||
|
{
|
||||||
|
"_id": "4",
|
||||||
|
"id": "user/dataset",
|
||||||
|
"modelId": "user/dataset",
|
||||||
|
"private": false,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
gock.New("https://huggingface.co").
|
||||||
|
Get("/api/datasets/user/dataset").
|
||||||
|
Reply(200).
|
||||||
|
JSON(map[string]interface{}{
|
||||||
|
"id": "user/dataset",
|
||||||
|
"author": "user",
|
||||||
|
"private": false,
|
||||||
|
})
|
||||||
|
|
||||||
|
gock.New("https://huggingface.co").
|
||||||
|
Get(fmt.Sprintf("/%s/%s", APIRoute, getResourceAPIPath(SPACE))).
|
||||||
|
MatchParam("author", "user").
|
||||||
|
MatchParam("limit", "1000").
|
||||||
|
Reply(200).
|
||||||
|
JSON([]map[string]interface{}{
|
||||||
|
{
|
||||||
|
"_id": "6",
|
||||||
|
"id": "user/space",
|
||||||
|
"modelId": "user/space",
|
||||||
|
"private": false,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
gock.New("https://huggingface.co").
|
||||||
|
Get("/api/spaces/user/space").
|
||||||
|
Reply(200).
|
||||||
|
JSON(map[string]interface{}{
|
||||||
|
"id": "user/space",
|
||||||
|
"author": "user",
|
||||||
|
"private": false,
|
||||||
|
})
|
||||||
|
|
||||||
|
err = s.enumerate(context.Background())
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
modelGitURL := "https://huggingface.co/user/model.git"
|
||||||
|
datasetGitURL := "https://huggingface.co/datasets/user/dataset.git"
|
||||||
|
spaceGitURL := "https://huggingface.co/spaces/user/space.git"
|
||||||
|
|
||||||
|
assert.Equal(t, []string{modelGitURL}, s.models)
|
||||||
|
assert.Equal(t, []string{datasetGitURL}, s.datasets)
|
||||||
|
assert.Equal(t, []string{spaceGitURL}, s.spaces)
|
||||||
|
|
||||||
|
r, _ := s.repoInfoCache.get(modelGitURL)
|
||||||
|
assert.Equal(t, r, repoInfo{
|
||||||
|
visibility: source_metadatapb.Visibility_public,
|
||||||
|
resourceType: MODEL,
|
||||||
|
owner: "user",
|
||||||
|
name: "model",
|
||||||
|
fullName: "user/model",
|
||||||
|
})
|
||||||
|
|
||||||
|
r, _ = s.repoInfoCache.get(datasetGitURL)
|
||||||
|
assert.Equal(t, r, repoInfo{
|
||||||
|
visibility: source_metadatapb.Visibility_public,
|
||||||
|
resourceType: DATASET,
|
||||||
|
owner: "user",
|
||||||
|
name: "dataset",
|
||||||
|
fullName: "user/dataset",
|
||||||
|
})
|
||||||
|
|
||||||
|
r, _ = s.repoInfoCache.get(spaceGitURL)
|
||||||
|
assert.Equal(t, r, repoInfo{
|
||||||
|
visibility: source_metadatapb.Visibility_public,
|
||||||
|
resourceType: SPACE,
|
||||||
|
owner: "user",
|
||||||
|
name: "space",
|
||||||
|
fullName: "user/space",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEnumerateAuthorsUserSkipAll(t *testing.T) {
|
||||||
|
s, conn := createTestSource(&sourcespb.Huggingface{
|
||||||
|
Endpoint: "https://huggingface.co",
|
||||||
|
Credential: &sourcespb.Huggingface_Token{
|
||||||
|
Token: "super secret token",
|
||||||
|
},
|
||||||
|
Users: []string{"user"},
|
||||||
|
SkipAllModels: true,
|
||||||
|
SkipAllDatasets: true,
|
||||||
|
SkipAllSpaces: true,
|
||||||
|
})
|
||||||
|
|
||||||
|
err := s.Init(context.Background(), "test - huggingface", 0, 1337, false, conn, 1)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
defer gock.Off()
|
||||||
|
|
||||||
|
gock.New("https://huggingface.co").
|
||||||
|
Get(fmt.Sprintf("/%s/%s", APIRoute, getResourceAPIPath(MODEL))).
|
||||||
|
MatchParam("author", "user").
|
||||||
|
MatchParam("limit", "1000").
|
||||||
|
Reply(200).
|
||||||
|
JSON([]map[string]interface{}{
|
||||||
|
{
|
||||||
|
"_id": "2",
|
||||||
|
"id": "user/model",
|
||||||
|
"modelId": "user/model",
|
||||||
|
"private": false,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
gock.New("https://huggingface.co").
|
||||||
|
Get("/api/models/user/model").
|
||||||
|
Reply(200).
|
||||||
|
JSON(map[string]interface{}{
|
||||||
|
"id": "user/model",
|
||||||
|
"author": "user",
|
||||||
|
"private": false,
|
||||||
|
})
|
||||||
|
|
||||||
|
gock.New("https://huggingface.co").
|
||||||
|
Get(fmt.Sprintf("/%s/%s", APIRoute, getResourceAPIPath(DATASET))).
|
||||||
|
MatchParam("author", "user").
|
||||||
|
MatchParam("limit", "1000").
|
||||||
|
Reply(200).
|
||||||
|
JSON([]map[string]interface{}{
|
||||||
|
{
|
||||||
|
"_id": "4",
|
||||||
|
"id": "user/dataset",
|
||||||
|
"modelId": "user/dataset",
|
||||||
|
"private": false,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
gock.New("https://huggingface.co").
|
||||||
|
Get("/api/datasets/user/dataset").
|
||||||
|
Reply(200).
|
||||||
|
JSON(map[string]interface{}{
|
||||||
|
"id": "user/dataset",
|
||||||
|
"author": "user",
|
||||||
|
"private": false,
|
||||||
|
})
|
||||||
|
|
||||||
|
gock.New("https://huggingface.co").
|
||||||
|
Get(fmt.Sprintf("/%s/%s", APIRoute, getResourceAPIPath(SPACE))).
|
||||||
|
MatchParam("author", "user").
|
||||||
|
MatchParam("limit", "1000").
|
||||||
|
Reply(200).
|
||||||
|
JSON([]map[string]interface{}{
|
||||||
|
{
|
||||||
|
"_id": "6",
|
||||||
|
"id": "user/space",
|
||||||
|
"modelId": "user/space",
|
||||||
|
"private": false,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
gock.New("https://huggingface.co").
|
||||||
|
Get("/api/spaces/user/space").
|
||||||
|
Reply(200).
|
||||||
|
JSON(map[string]interface{}{
|
||||||
|
"id": "user/space",
|
||||||
|
"author": "user",
|
||||||
|
"private": false,
|
||||||
|
})
|
||||||
|
|
||||||
|
err = s.enumerate(context.Background())
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
modelGitURL := "https://huggingface.co/user/model.git"
|
||||||
|
datasetGitURL := "https://huggingface.co/datasets/user/dataset.git"
|
||||||
|
spaceGitURL := "https://huggingface.co/spaces/user/space.git"
|
||||||
|
|
||||||
|
assert.Equal(t, []string{}, s.models)
|
||||||
|
assert.Equal(t, []string{}, s.datasets)
|
||||||
|
assert.Equal(t, []string{}, s.spaces)
|
||||||
|
|
||||||
|
r, _ := s.repoInfoCache.get(modelGitURL)
|
||||||
|
assert.Equal(t, r, repoInfo{})
|
||||||
|
|
||||||
|
r, _ = s.repoInfoCache.get(datasetGitURL)
|
||||||
|
assert.Equal(t, r, repoInfo{})
|
||||||
|
|
||||||
|
r, _ = s.repoInfoCache.get(spaceGitURL)
|
||||||
|
assert.Equal(t, r, repoInfo{})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEnumerateAuthorsUserIgnores(t *testing.T) {
    s, conn := createTestSource(&sourcespb.Huggingface{
        Endpoint: "https://huggingface.co",
        Credential: &sourcespb.Huggingface_Token{
            Token: "super secret token",
        },
        Users:          []string{"user"},
        IgnoreModels:   []string{"user/model"},
        IgnoreDatasets: []string{"user/dataset"},
        IgnoreSpaces:   []string{"user/space"},
    })

    err := s.Init(context.Background(), "test - huggingface", 0, 1337, false, conn, 1)
    assert.Nil(t, err)

    defer gock.Off()

    gock.New("https://huggingface.co").
        Get(fmt.Sprintf("/%s/%s", APIRoute, getResourceAPIPath(MODEL))).
        MatchParam("author", "user").
        MatchParam("limit", "1000").
        Reply(200).
        JSON([]map[string]interface{}{
            {
                "_id":     "2",
                "id":      "user/model",
                "modelId": "user/model",
                "private": false,
            },
        })

    gock.New("https://huggingface.co").
        Get("/api/models/user/model").
        Reply(200).
        JSON(map[string]interface{}{
            "id":      "user/model",
            "author":  "user",
            "private": false,
        })

    gock.New("https://huggingface.co").
        Get(fmt.Sprintf("/%s/%s", APIRoute, getResourceAPIPath(DATASET))).
        MatchParam("author", "user").
        MatchParam("limit", "1000").
        Reply(200).
        JSON([]map[string]interface{}{
            {
                "_id":     "4",
                "id":      "user/dataset",
                "modelId": "user/dataset",
                "private": false,
            },
        })

    gock.New("https://huggingface.co").
        Get("/api/datasets/user/dataset").
        Reply(200).
        JSON(map[string]interface{}{
            "id":      "user/dataset",
            "author":  "user",
            "private": false,
        })

    gock.New("https://huggingface.co").
        Get(fmt.Sprintf("/%s/%s", APIRoute, getResourceAPIPath(SPACE))).
        MatchParam("author", "user").
        MatchParam("limit", "1000").
        Reply(200).
        JSON([]map[string]interface{}{
            {
                "_id":     "6",
                "id":      "user/space",
                "modelId": "user/space",
                "private": false,
            },
        })

    gock.New("https://huggingface.co").
        Get("/api/spaces/user/space").
        Reply(200).
        JSON(map[string]interface{}{
            "id":      "user/space",
            "author":  "user",
            "private": false,
        })

    err = s.enumerate(context.Background())
    assert.Nil(t, err)

    modelGitURL := "https://huggingface.co/user/model.git"
    datasetGitURL := "https://huggingface.co/datasets/user/dataset.git"
    spaceGitURL := "https://huggingface.co/spaces/user/space.git"

    assert.Equal(t, []string{}, s.models)
    assert.Equal(t, []string{}, s.datasets)
    assert.Equal(t, []string{}, s.spaces)

    r, _ := s.repoInfoCache.get(modelGitURL)
    assert.Equal(t, r, repoInfo{})

    r, _ = s.repoInfoCache.get(datasetGitURL)
    assert.Equal(t, r, repoInfo{})

    r, _ = s.repoInfoCache.get(spaceGitURL)
    assert.Equal(t, r, repoInfo{})
}

func TestVerifySlashSeparatedStrings(t *testing.T) {
    assert.Error(t, verifySlashSeparatedStrings([]string{"orgmodel"}))
    assert.NoError(t, verifySlashSeparatedStrings([]string{"org/model"}))
    assert.Error(t, verifySlashSeparatedStrings([]string{"org/model", "orgmodel2"}))
    assert.NoError(t, verifySlashSeparatedStrings([]string{"org/model", "org/model2"}))
}
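
The helper under test is not included in this part of the diff. For orientation only, a minimal validator consistent with these assertions could look like the sketch below; the function name comes from the test, but the body is an assumption, not the commit's implementation.

```go
package huggingface

import (
	"fmt"
	"strings"
)

// verifySlashSeparatedStrings is a sketch of a validator consistent with the
// test above: every entry must be of the form "<owner>/<name>". The actual
// implementation in this commit may differ.
func verifySlashSeparatedStrings(names []string) error {
	for _, name := range names {
		if !strings.Contains(name, "/") {
			return fmt.Errorf("%s must be in the format <owner>/<name>", name)
		}
	}
	return nil
}
```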

func TestValidateIgnoreIncludeReposRepos(t *testing.T) {
    s, conn := createTestSource(&sourcespb.Huggingface{
        Endpoint: "https://huggingface.co",
        Credential: &sourcespb.Huggingface_Token{
            Token: "super secret token",
        },
        Organizations:  []string{"org"},
        IgnoreModels:   []string{"orgmodel1"},
        IgnoreDatasets: []string{"org/dataset1"},
        IgnoreSpaces:   []string{"org/space1"},
    })

    err := s.Init(context.Background(), "test - huggingface", 0, 1337, false, conn, 1)
    assert.NotNil(t, err)
}

func TestValidateIgnoreIncludeReposDatasets(t *testing.T) {
    s, conn := createTestSource(&sourcespb.Huggingface{
        Endpoint: "https://huggingface.co",
        Credential: &sourcespb.Huggingface_Token{
            Token: "super secret token",
        },
        Organizations:  []string{"org"},
        IgnoreModels:   []string{"org/model1"},
        IgnoreDatasets: []string{"orgdataset1"},
        IgnoreSpaces:   []string{"org/space1"},
    })

    err := s.Init(context.Background(), "test - huggingface", 0, 1337, false, conn, 1)
    assert.NotNil(t, err)
}

func TestValidateIgnoreIncludeReposSpaces(t *testing.T) {
    s, conn := createTestSource(&sourcespb.Huggingface{
        Endpoint: "https://huggingface.co",
        Credential: &sourcespb.Huggingface_Token{
            Token: "super secret token",
        },
        Organizations:  []string{"org"},
        IgnoreModels:   []string{"org/model1"},
        IgnoreDatasets: []string{"org/dataset1"},
        IgnoreSpaces:   []string{"orgspace1"},
    })

    err := s.Init(context.Background(), "test - huggingface", 0, 1337, false, conn, 1)
    assert.NotNil(t, err)
}

// repeat this with all skip flags, and then include/ignore flags

pkg/sources/huggingface/repo.go (new file, 72 lines)
@@ -0,0 +1,72 @@
package huggingface

import (
    "fmt"
    "sync"

    gogit "github.com/go-git/go-git/v5"
    "github.com/trufflesecurity/trufflehog/v3/pkg/context"
    "github.com/trufflesecurity/trufflehog/v3/pkg/pb/source_metadatapb"
    "github.com/trufflesecurity/trufflehog/v3/pkg/pb/sourcespb"
    "github.com/trufflesecurity/trufflehog/v3/pkg/sources/git"
)

type repoInfoCache struct {
    mu    sync.RWMutex
    cache map[string]repoInfo
}

func newRepoInfoCache() repoInfoCache {
    return repoInfoCache{
        cache: make(map[string]repoInfo),
    }
}

func (r *repoInfoCache) put(repoURL string, info repoInfo) {
    r.mu.Lock()
    defer r.mu.Unlock()
    r.cache[repoURL] = info
}

func (r *repoInfoCache) get(repoURL string) (repoInfo, bool) {
    r.mu.RLock()
    defer r.mu.RUnlock()

    info, ok := r.cache[repoURL]
    return info, ok
}

type repoInfo struct {
    owner        string
    name         string
    fullName     string
    visibility   source_metadatapb.Visibility
    resourceType resourceType
}

func (s *Source) cloneRepo(
    ctx context.Context,
    repoURL string,
) (string, *gogit.Repository, error) {
    var (
        path string
        repo *gogit.Repository
        err  error
    )

    switch s.conn.GetCredential().(type) {
    case *sourcespb.Huggingface_Unauthenticated:
        path, repo, err = git.CloneRepoUsingUnauthenticated(ctx, repoURL)
        if err != nil {
            return "", nil, err
        }
    case *sourcespb.Huggingface_Token:
        path, repo, err = git.CloneRepoUsingToken(ctx, s.huggingfaceToken, repoURL, "")
        if err != nil {
            return "", nil, err
        }
    default:
        return "", nil, fmt.Errorf("unhandled credential type for repo %s: %T", repoURL, s.conn.GetCredential())
    }
    return path, repo, nil
}
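
For orientation, the tests earlier in the diff pair this cache with `cloneRepo`: metadata is stored per clone URL and looked up again before cloning. The snippet below is an illustrative sketch, not part of the commit; it assumes it would live inside the `huggingface` package, since `Source`, `repoInfo`, and `repoInfoCache` are unexported.

```go
// Illustrative sketch only -- not part of the commit. Assumes it is compiled
// inside the huggingface package so it can use the unexported types above.
func exampleRepoInfoCacheUsage(ctx context.Context, s *Source) error {
	repoURL := "https://huggingface.co/user/model.git"

	// Enumeration phase: remember what this clone URL refers to.
	s.repoInfoCache.put(repoURL, repoInfo{
		owner:    "user",
		name:     "model",
		fullName: "user/model",
	})

	// Scan phase: look the metadata back up, then clone for chunking.
	if _, ok := s.repoInfoCache.get(repoURL); ok {
		path, repo, err := s.cloneRepo(ctx, repoURL)
		if err != nil {
			return err
		}
		_, _ = path, repo // in the real source these are handed to the git scanner
	}
	return nil
}
```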

@@ -136,6 +136,19 @@ message GCS {
   string content_type = 8;
 }
+
+message Huggingface {
+  string link = 1;
+  string username = 2;
+  string repository = 3;
+  string commit = 4;
+  string email = 5;
+  string file = 6;
+  string timestamp = 7;
+  int64 line = 8;
+  Visibility visibility = 9;
+  string resource_type = 10;
+}
 
 message Jira {
   string issue = 1;
   string author = 2;

@@ -351,5 +364,6 @@ message MetaData {
     Postman postman = 29;
     Webhook webhook = 30;
     Elasticsearch elasticsearch = 31;
+    Huggingface huggingface = 32;
   }
 }

@@ -48,6 +48,7 @@ enum SourceType {
   SOURCE_TYPE_POSTMAN = 33;
   SOURCE_TYPE_WEBHOOK = 34;
   SOURCE_TYPE_ELASTICSEARCH = 35;
+  SOURCE_TYPE_HUGGINGFACE = 36;
 }
 
 message LocalSource {

@@ -248,6 +249,30 @@ message GoogleDrive {
   }
 }
+
+message Huggingface {
+  string endpoint = 1 [(validate.rules).string.uri_ref = true];
+  oneof credential {
+    string token = 2;
+    credentials.Unauthenticated unauthenticated = 3;
+  }
+  repeated string models = 4;
+  repeated string spaces = 5;
+  repeated string datasets = 12;
+  repeated string organizations = 6;
+  repeated string users = 15;
+  repeated string ignore_models = 7;
+  repeated string include_models = 8;
+  repeated string ignore_spaces = 9;
+  repeated string include_spaces = 10;
+  repeated string ignore_datasets = 13;
+  repeated string include_datasets = 14;
+  bool skip_all_models = 16;
+  bool skip_all_spaces = 17;
+  bool skip_all_datasets = 18;
+  bool include_discussions = 11;
+  bool include_prs = 19;
+}
 
 message JIRA {
   string endpoint = 1 [(validate.rules).string.uri_ref = true];
   oneof credential {
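
The source and its tests consume the generated Go bindings for this connection message. As a rough sketch (assuming the `sourcespb` import path shown in repo.go above), it might be configured programmatically like this; the `Endpoint`, `Credential`, `Organizations`, and `Ignore*` field names appear in the tests earlier in the diff, while the `SkipAll*`, `IncludeDiscussions`, and `IncludePrs` names are assumed from standard protoc-gen-go naming.

```go
// Sketch of building the HuggingFace source configuration; the resource name
// used for IgnoreModels is a placeholder.
func exampleConnection() *sourcespb.Huggingface {
	return &sourcespb.Huggingface{
		Endpoint: "https://huggingface.co",
		Credential: &sourcespb.Huggingface_Token{
			Token: "super secret token",
		},
		Organizations:      []string{"trufflesecurity"},
		IgnoreModels:       []string{"trufflesecurity/some-model"},
		SkipAllDatasets:    true,
		IncludeDiscussions: true,
		IncludePrs:         true,
	}
}
```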