diff --git a/jar/backup.go b/jar/backup.go new file mode 100644 index 0000000..26bead3 --- /dev/null +++ b/jar/backup.go @@ -0,0 +1,57 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package jar + +import ( + "fmt" + "io" + "os" + "path/filepath" +) + +// backupFile makes copies JARs file in the backup folder +func backupFile(jarPath, jarSource, jarFile string) error { + pathdst := filepath.Join(jarPath, jarFile+".bak") + + _, err := copyJars(jarSource, pathdst) + if err != nil { + return err + } + return nil +} + +func copyJars(src, dst string) (int64, error) { + sourceFileStat, err := os.Stat(src) + if err != nil { + return 0, err + } + + if m := sourceFileStat.Mode(); !m.IsRegular() { + return 0, fmt.Errorf("%s is not a regular file: %s", src, m) + } + + source, err := os.Open(src) + if err != nil { + return 0, err + } + defer source.Close() + + destination, err := os.Create(dst) + if err != nil { + return 0, err + } + defer destination.Close() + nBytes, err := io.Copy(destination, source) + return nBytes, err +} diff --git a/jar/walker.go b/jar/walker.go index d7fee31..c8c74ec 100644 --- a/jar/walker.go +++ b/jar/walker.go @@ -51,6 +51,8 @@ type Walker struct { // Rewrite indicates if the Walker should rewrite JARs in place as it // iterates through the filesystem. Rewrite bool + // Backup indicates if the Walker should create a backups from JARs. + Backup bool // SkipDir, if provided, allows the walker to skip certain directories // as it scans. SkipDir func(path string, de fs.DirEntry) bool @@ -133,6 +135,7 @@ func (w *walker) skipDir(path string, d fs.DirEntry) bool { } func (w *walker) visit(p string, d fs.DirEntry) error { + if d.IsDir() || !d.Type().IsRegular() { return nil } @@ -178,6 +181,13 @@ func (w *walker) visit(p string, d fs.DirEntry) error { return nil } + if w.Backup { + err := backupFile(w.dir, w.filepath(p), p) + if err != nil { + return fmt.Errorf("backup: %v", err) + } + } + dest := w.filepath(p) // Ensure temp file is created in the same directory as the file we want to // rewrite to improve the chances of ending up on the same filesystem. On @@ -215,5 +225,6 @@ func (w *walker) visit(p string, d fs.DirEntry) error { return fmt.Errorf("overwriting %s: %v", p, err) } w.handleRewrite(p, r) + return nil } diff --git a/log4jscanner.go b/log4jscanner.go index fdad047..cbade13 100644 --- a/log4jscanner.go +++ b/log4jscanner.go @@ -40,6 +40,7 @@ Flags: -f, --force Don't skip network and userland filesystems. (smb,nfs,afs,fuse) -w, --rewrite Rewrite vulnerable JARs as they are detected. -v, --verbose Print verbose logs to stderr. + -b, --backup Make a backup of the scanned files `) } @@ -49,8 +50,8 @@ var skipDirs = map[string]bool{ ".git": true, "node_modules": true, ".idea": true, - ".svn": true, - ".p4root": true, + ".svn": true, + ".p4root": true, // TODO(ericchiang): expand } @@ -63,6 +64,8 @@ func main() { v bool force bool f bool + backup bool + b bool toSkip []string ) appendSkip := func(dir string) error { @@ -76,6 +79,8 @@ func main() { flag.BoolVar(&v, "v", false, "") flag.BoolVar(&force, "force", false, "") flag.BoolVar(&f, "f", false, "") + flag.BoolVar(&b, "b", false, "") + flag.BoolVar(&backup, "backup", false, "") flag.Func("s", "", appendSkip) flag.Func("skip", "", appendSkip) flag.Usage = usage @@ -94,6 +99,9 @@ func main() { if w { rewrite = w } + if b { + backup = b + } log.SetFlags(log.LstdFlags | log.Lshortfile) logf := func(format string, v ...interface{}) { if verbose { @@ -103,6 +111,7 @@ func main() { seen := 0 walker := jar.Walker{ Rewrite: rewrite, + Backup: backup, SkipDir: func(path string, d fs.DirEntry) bool { seen++ if seen%5000 == 0 {