modifying ids requires augmenting relationships

Signed-off-by: Alex Goodman <wagoodman@users.noreply.github.com>
This commit is contained in:
Alex Goodman 2024-09-19 16:54:43 -04:00
parent 4f0132a98c
commit 9d281d1d99
5 changed files with 164 additions and 99 deletions

View file

@ -17,20 +17,16 @@ type Index struct {
// NewIndex returns a new relationship Index
func NewIndex(relationships ...artifact.Relationship) *Index {
out := Index{}
out := Index{
fromID: make(map[artifact.ID]*mappedRelationships),
toID: make(map[artifact.ID]*mappedRelationships),
}
out.Add(relationships...)
return &out
}
// Add adds all the given relationships to the index, without adding duplicates
func (i *Index) Add(relationships ...artifact.Relationship) {
if i.fromID == nil {
i.fromID = map[artifact.ID]*mappedRelationships{}
}
if i.toID == nil {
i.toID = map[artifact.ID]*mappedRelationships{}
}
// store appropriate indexes for stable ordering to minimize ID() calls
for _, r := range relationships {
// prevent duplicates
@ -71,6 +67,7 @@ func (i *Index) Add(relationships ...artifact.Relationship) {
func (i *Index) Remove(id artifact.ID) {
delete(i.fromID, id)
delete(i.toID, id)
for idx := 0; idx < len(i.all); {
if i.all[idx].from == id || i.all[idx].to == id {
i.all = append(i.all[:idx], i.all[idx+1:]...)
@ -80,6 +77,26 @@ func (i *Index) Remove(id artifact.ID) {
}
}
func (i *Index) Replace(ogID artifact.ID, replacement artifact.Identifiable) {
for _, mapped := range fromMappedByID(i.fromID, ogID) {
i.Add(artifact.Relationship{
From: replacement,
To: mapped.relationship.To,
Type: mapped.relationship.Type,
})
}
for _, mapped := range fromMappedByID(i.toID, ogID) {
i.Add(artifact.Relationship{
From: mapped.relationship.From,
To: replacement,
Type: mapped.relationship.Type,
})
}
i.Remove(ogID)
}
// From returns all relationships from the given identifiable, with specified types
func (i *Index) From(identifiable artifact.Identifiable, types ...artifact.RelationshipType) []artifact.Relationship {
return toSortedSlice(fromMapped(i.fromID, identifiable), types)
@ -122,10 +139,17 @@ func (i *Index) All(types ...artifact.RelationshipType) []artifact.Relationship
}
func fromMapped(idMap map[artifact.ID]*mappedRelationships, identifiable artifact.Identifiable) []*sortableRelationship {
if identifiable == nil || idMap == nil {
if identifiable == nil {
return nil
}
mapped := idMap[identifiable.ID()]
return fromMappedByID(idMap, identifiable.ID())
}
func fromMappedByID(idMap map[artifact.ID]*mappedRelationships, id artifact.ID) []*sortableRelationship {
if idMap == nil {
return nil
}
mapped := idMap[id]
if mapped == nil {
return nil
}

View file

@ -327,3 +327,84 @@ func TestRemove(t *testing.T) {
assert.Empty(t, index.From(c3))
assert.Empty(t, index.To(c3))
}
func TestReplace(t *testing.T) {
p1 := pkg.Package{Name: "pkg-1"}
p2 := pkg.Package{Name: "pkg-2"}
p3 := pkg.Package{Name: "pkg-3"}
p4 := pkg.Package{Name: "pkg-4"}
for _, p := range []*pkg.Package{&p1, &p2, &p3, &p4} {
p.SetID()
}
r1 := artifact.Relationship{
From: p1,
To: p2,
Type: artifact.DependencyOfRelationship,
}
r2 := artifact.Relationship{
From: p3,
To: p1,
Type: artifact.DependencyOfRelationship,
}
r3 := artifact.Relationship{
From: p2,
To: p3,
Type: artifact.ContainsRelationship,
}
index := NewIndex(r1, r2, r3)
// replace p1 with p4 in the relationships
index.Replace(p1.ID(), &p4)
expectedRels := []artifact.Relationship{
{
From: p4, // replaced
To: p2,
Type: artifact.DependencyOfRelationship,
},
{
From: p3,
To: p4, // replaced
Type: artifact.DependencyOfRelationship,
},
{
From: p2,
To: p3,
Type: artifact.ContainsRelationship,
},
}
compareRelationships(t, expectedRels, index.All())
}
func compareRelationships(t testing.TB, expected, actual []artifact.Relationship) {
assert.Equal(t, len(expected), len(actual), "number of relationships should match")
for _, e := range expected {
found := false
for _, a := range actual {
if a.From.ID() == e.From.ID() && a.To.ID() == e.To.ID() && a.Type == e.Type {
found = true
break
}
}
assert.True(t, found, "expected relationship not found: %+v", e)
}
}
func TestReplace_NoExistingRelations(t *testing.T) {
p1 := pkg.Package{Name: "pkg-1"}
p2 := pkg.Package{Name: "pkg-2"}
p1.SetID()
p2.SetID()
index := NewIndex()
index.Replace(p1.ID(), &p2)
allRels := index.All()
assert.Len(t, allRels, 0)
}

View file

@ -2,7 +2,6 @@ package task
import (
"context"
"errors"
"fmt"
"sort"
"strings"
@ -112,10 +111,7 @@ func NewPackageTask(cfg CatalogingFactoryConfig, c pkg.Cataloger, tags ...string
pkgs, relationships = finalizePkgCatalogerResults(cfg, resolver, catalogerName, pkgs, relationships)
pkgs, relationships, err = applyCompliance(cfg.ComplianceConfig, catalogerName, pkgs, relationships)
if err != nil {
return err
}
pkgs, relationships = applyCompliance(cfg.ComplianceConfig, pkgs, relationships)
sbom.AddPackages(pkgs...)
sbom.AddRelationships(relationships...)
@ -171,45 +167,47 @@ func finalizePkgCatalogerResults(cfg CatalogingFactoryConfig, resolver file.Path
return pkgs, relationships
}
func applyCompliance(cfg cataloging.ComplianceConfig, catalogerName string, pkgs []pkg.Package, relationships []artifact.Relationship) ([]pkg.Package, []artifact.Relationship, error) {
remainingPkgs, droppedPkgs, err := filterNonCompliantPackages(pkgs, cfg)
var nonCompliantErr cataloging.ErrNonCompliantPackages
if errors.As(err, &nonCompliantErr) {
log.WithFields("cataloger", catalogerName).Errorf(nonCompliantErr.Error())
return nil, nil, err
}
if err != nil {
return nil, nil, fmt.Errorf("unable to filter non-compliant packages: %w", err)
}
type packageReplacement struct {
original artifact.ID
pkg pkg.Package
}
func applyCompliance(cfg cataloging.ComplianceConfig, pkgs []pkg.Package, relationships []artifact.Relationship) ([]pkg.Package, []artifact.Relationship) {
remainingPkgs, droppedPkgs, replacements := filterNonCompliantPackages(pkgs, cfg)
relIdx := relationship.NewIndex(relationships...)
for _, p := range droppedPkgs {
relIdx.Remove(p.ID())
}
return remainingPkgs, relIdx.All(), nil
for _, replacement := range replacements {
relIdx.Replace(replacement.original, replacement.pkg)
}
return remainingPkgs, relIdx.All()
}
func filterNonCompliantPackages(pkgs []pkg.Package, cfg cataloging.ComplianceConfig) ([]pkg.Package, []pkg.Package, error) {
func filterNonCompliantPackages(pkgs []pkg.Package, cfg cataloging.ComplianceConfig) ([]pkg.Package, []pkg.Package, []packageReplacement) {
var remainingPkgs, droppedPkgs []pkg.Package
errNonCompliant := cataloging.NewErrNonCompliantPackages()
var replacements []packageReplacement
for _, p := range pkgs {
if applyComplianceRules(&p, cfg, errNonCompliant) {
keep, replacement := applyComplianceRules(&p, cfg)
if keep {
remainingPkgs = append(remainingPkgs, p)
} else {
droppedPkgs = append(droppedPkgs, p)
}
if replacement != nil {
replacements = append(replacements, *replacement)
}
}
if len(errNonCompliant.NonCompliantPackageLocations) > 0 {
return nil, nil, errNonCompliant
}
return remainingPkgs, droppedPkgs, nil
return remainingPkgs, droppedPkgs, replacements
}
func applyComplianceRules(p *pkg.Package, cfg cataloging.ComplianceConfig, errNonCompliant *cataloging.ErrNonCompliantPackages) bool {
func applyComplianceRules(p *pkg.Package, cfg cataloging.ComplianceConfig) (bool, *packageReplacement) {
var drop bool
var replacement *packageReplacement
applyComplianceRule := func(value, fieldName string, action cataloging.ComplianceAction) bool {
if strings.TrimSpace(value) != "" {
@ -225,11 +223,9 @@ func applyComplianceRules(p *pkg.Package, cfg cataloging.ComplianceConfig, errNo
case cataloging.ComplianceActionDrop:
log.WithFields("pkg", p.String(), "location", loc).Debugf("package with missing %s, dropping", fieldName)
drop = true
case cataloging.ComplianceActionWarn:
log.WithFields("pkg", p.String(), "location", loc).Warnf("package with missing %s, failing", fieldName)
case cataloging.ComplianceActionFail:
errNonCompliant.AddInfo(loc, p.String(), fmt.Sprintf("missing %s", fieldName))
case cataloging.ComplianceActionStub:
log.WithFields("pkg", p.String(), "location", loc).Debugf("package with missing %s, stubbing with default value", fieldName)
return true
case cataloging.ComplianceActionKeep:
@ -238,6 +234,8 @@ func applyComplianceRules(p *pkg.Package, cfg cataloging.ComplianceConfig, errNo
return false
}
ogID := p.ID()
if applyComplianceRule(p.Name, "name", cfg.MissingName) {
p.Name = cataloging.UnknownStubValue
p.SetID()
@ -248,7 +246,15 @@ func applyComplianceRules(p *pkg.Package, cfg cataloging.ComplianceConfig, errNo
p.SetID()
}
return !drop && len(errNonCompliant.NonCompliantPackageLocations) == 0
newID := p.ID()
if newID != ogID {
replacement = &packageReplacement{
original: ogID,
pkg: *p,
}
}
return !drop, replacement
}
func hasAuthoritativeCPE(cpes []cpe.CPE) bool {

View file

@ -84,11 +84,10 @@ func TestApplyCompliance(t *testing.T) {
cfg := cataloging.ComplianceConfig{
MissingName: cataloging.ComplianceActionDrop,
MissingVersion: cataloging.ComplianceActionWarn,
MissingVersion: cataloging.ComplianceActionStub,
}
remainingPkgs, remainingRels, err := applyCompliance(cfg, "test-cataloger", []pkg.Package{p1, p2, p3, p4}, []artifact.Relationship{r1, r2})
require.NoError(t, err)
remainingPkgs, remainingRels := applyCompliance(cfg, []pkg.Package{p1, p2, p3, p4}, []artifact.Relationship{r1, r2})
// p2 should be dropped because it has a missing name, p3 and p4 should pass with a warning for the missing version
assert.Len(t, remainingPkgs, 3) // p1, p3, p4 should remain
@ -106,11 +105,11 @@ func TestFilterNonCompliantPackages(t *testing.T) {
cfg := cataloging.ComplianceConfig{
MissingName: cataloging.ComplianceActionDrop,
MissingVersion: cataloging.ComplianceActionWarn,
MissingVersion: cataloging.ComplianceActionKeep,
}
remainingPkgs, droppedPkgs, err := filterNonCompliantPackages([]pkg.Package{p1, p2, p3}, cfg)
require.NoError(t, err)
remainingPkgs, droppedPkgs, replacement := filterNonCompliantPackages([]pkg.Package{p1, p2, p3}, cfg)
require.Nil(t, replacement)
// p2 should be dropped because it has a missing name
assert.Len(t, remainingPkgs, 2)
@ -121,32 +120,21 @@ func TestFilterNonCompliantPackages(t *testing.T) {
func TestApplyComplianceRules_DropAndStub(t *testing.T) {
p := pkg.Package{Name: "", Version: ""}
p.SetID()
ogID := p.ID()
cfg := cataloging.ComplianceConfig{
MissingName: cataloging.ComplianceActionDrop,
MissingVersion: cataloging.ComplianceActionStub,
}
errNonCompliant := cataloging.NewErrNonCompliantPackages()
isCompliant := applyComplianceRules(&p, cfg, errNonCompliant)
isCompliant, replacement := applyComplianceRules(&p, cfg)
require.NotNil(t, replacement)
assert.Equal(t, packageReplacement{
original: ogID,
pkg: p,
}, *replacement)
// the package should be dropped due to missing name (drop action) and its version should be stubbed
assert.False(t, isCompliant)
assert.Equal(t, cataloging.UnknownStubValue, p.Version)
}
func TestApplyComplianceRules_Fail(t *testing.T) {
p1 := pkg.Package{Name: "", Version: "1.0"} // missing name
p1.SetID()
cfg := cataloging.ComplianceConfig{
MissingName: cataloging.ComplianceActionFail,
}
errNonCompliant := cataloging.NewErrNonCompliantPackages()
isCompliant := applyComplianceRules(&p1, cfg, errNonCompliant)
assert.False(t, isCompliant)
assert.Contains(t, errNonCompliant.NonCompliantPackageLocations, "unknown")
}

View file

@ -1,45 +1,15 @@
package cataloging
import (
"sort"
"strings"
)
const (
ComplianceActionKeep ComplianceAction = "keep"
ComplianceActionWarn ComplianceAction = "warn"
ComplianceActionDrop ComplianceAction = "drop"
ComplianceActionFail ComplianceAction = "fail"
ComplianceActionStub ComplianceAction = "stub"
)
type ErrNonCompliantPackages struct {
NonCompliantPackageLocations map[string][]string
}
func NewErrNonCompliantPackages() *ErrNonCompliantPackages {
return &ErrNonCompliantPackages{
NonCompliantPackageLocations: make(map[string][]string),
}
}
func (e *ErrNonCompliantPackages) AddInfo(location, info, note string) {
e.NonCompliantPackageLocations[location] = append(e.NonCompliantPackageLocations[location], note+": "+info)
}
func (e ErrNonCompliantPackages) Error() string {
var reasons []string
for location, infos := range e.NonCompliantPackageLocations {
for _, info := range infos {
reasons = append(reasons, location+": "+info)
}
}
sort.Strings(reasons)
return "non-compliant packages: " + strings.Join(reasons, "\n")
}
const UnknownStubValue = "UNKNOWN"
type ComplianceAction string
@ -68,14 +38,10 @@ func (c ComplianceAction) Parse() ComplianceAction {
switch strings.ToLower(string(c)) {
case string(ComplianceActionKeep), "include":
return ComplianceActionKeep
case string(ComplianceActionWarn), "warning":
return ComplianceActionWarn
case string(ComplianceActionDrop), "exclude":
return ComplianceActionDrop
case string(ComplianceActionFail), "error":
return ComplianceActionFail
case string(ComplianceActionStub), "replace":
return ComplianceActionStub
}
return ComplianceActionWarn
return ComplianceActionKeep
}