Skip to content

Commit bc138e1

Browse files
Refactor utils package to not dump everything unrelated into one file (#352)
1 parent 8aee0e4 commit bc138e1

File tree

10 files changed

+253
-225
lines changed

10 files changed

+253
-225
lines changed

google_guest_agent/addresses.go

Lines changed: 17 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -21,21 +21,19 @@ import (
2121
"net"
2222
"reflect"
2323
"runtime"
24+
"slices"
2425
"strings"
2526

2627
"github.com/GoogleCloudPlatform/guest-agent/google_guest_agent/cfg"
2728
network "github.com/GoogleCloudPlatform/guest-agent/google_guest_agent/network/manager"
2829
"github.com/GoogleCloudPlatform/guest-agent/google_guest_agent/run"
29-
"github.com/GoogleCloudPlatform/guest-agent/utils"
3030
"github.com/GoogleCloudPlatform/guest-logging-go/logger"
3131
)
3232

3333
var (
34-
addressKey = regKeyBase + `\ForwardedIps`
35-
oldWSFCAddresses string
36-
oldWSFCEnable bool
37-
interfacesEnabled bool
38-
interfaces []net.Interface
34+
addressKey = regKeyBase + `\ForwardedIps`
35+
oldWSFCAddresses string
36+
oldWSFCEnable bool
3937
)
4038

4139
type addressMgr struct{}
@@ -76,7 +74,9 @@ func getForwardsFromRegistry(mac string) ([]string, error) {
7674
oldName := strings.Replace(mac, ":", "", -1)
7775
regFwdIPs, err = readRegMultiString(addressKey, oldName)
7876
if err == nil {
79-
deleteRegKey(addressKey, oldName)
77+
if err = deleteRegKey(addressKey, oldName); err != nil {
78+
logger.Warningf("Failed to delete key: %q, name: %q from registry", addressKey, oldName)
79+
}
8080
}
8181
} else if err != nil {
8282
return nil, err
@@ -86,13 +86,13 @@ func getForwardsFromRegistry(mac string) ([]string, error) {
8686

8787
func compareRoutes(configuredRoutes, desiredRoutes []string) (toAdd, toRm []string) {
8888
for _, desiredRoute := range desiredRoutes {
89-
if !utils.ContainsString(desiredRoute, configuredRoutes) {
89+
if !slices.Contains(configuredRoutes, desiredRoute) {
9090
toAdd = append(toAdd, desiredRoute)
9191
}
9292
}
9393

9494
for _, configuredRoute := range configuredRoutes {
95-
if !utils.ContainsString(configuredRoute, desiredRoutes) {
95+
if !slices.Contains(desiredRoutes, configuredRoute) {
9696
toRm = append(toRm, configuredRoute)
9797
}
9898
}
@@ -205,15 +205,15 @@ func (a *addressMgr) applyWSFCFilter(config *cfg.Sections) {
205205
for idx := range interfaces {
206206
var filteredForwardedIps []string
207207
for _, ip := range interfaces[idx].ForwardedIps {
208-
if !utils.ContainsString(ip, wsfcAddrs) {
208+
if !slices.Contains(wsfcAddrs, ip) {
209209
filteredForwardedIps = append(filteredForwardedIps, ip)
210210
}
211211
}
212212
interfaces[idx].ForwardedIps = filteredForwardedIps
213213

214214
var filteredTargetInstanceIps []string
215215
for _, ip := range interfaces[idx].TargetInstanceIps {
216-
if !utils.ContainsString(ip, wsfcAddrs) {
216+
if !slices.Contains(wsfcAddrs, ip) {
217217
filteredTargetInstanceIps = append(filteredTargetInstanceIps, ip)
218218
}
219219
}
@@ -274,14 +274,8 @@ func (a *addressMgr) Set(ctx context.Context) error {
274274
a.applyWSFCFilter(config)
275275
}
276276

277-
var err error
278-
interfaces, err = net.Interfaces()
279-
if err != nil {
280-
return fmt.Errorf("error populating interfaces: %v", err)
281-
}
282-
283277
// Setup network interfaces.
284-
err = network.SetupInterfaces(ctx, config, newMetadata.Instance.NetworkInterfaces)
278+
err := network.SetupInterfaces(ctx, config, newMetadata.Instance.NetworkInterfaces)
285279
if err != nil {
286280
return fmt.Errorf("failed to setup network interfaces: %v", err)
287281
}
@@ -295,7 +289,7 @@ func (a *addressMgr) Set(ctx context.Context) error {
295289
for _, ni := range newMetadata.Instance.NetworkInterfaces {
296290
iface, err := network.GetInterfaceByMAC(ni.Mac)
297291
if err != nil {
298-
if !utils.ContainsString(ni.Mac, badMAC) {
292+
if !slices.Contains(badMAC, ni.Mac) {
299293
logger.Errorf("Error getting interface: %s", err)
300294
badMAC = append(badMAC, ni.Mac)
301295
}
@@ -328,7 +322,7 @@ func (a *addressMgr) Set(ctx context.Context) error {
328322
}
329323
for _, ip := range configuredIPs {
330324
// Only add to `forwardedIPs` if it is recorded in the registry.
331-
if utils.ContainsString(ip, regFwdIPs) {
325+
if slices.Contains(regFwdIPs, ip) {
332326
forwardedIPs = append(forwardedIPs, ip)
333327
}
334328
}
@@ -371,14 +365,14 @@ func (a *addressMgr) Set(ctx context.Context) error {
371365
var registryEntries []string
372366
for _, ip := range wantIPs {
373367
// If the IP is not in toAdd, add to registry list and continue.
374-
if !utils.ContainsString(ip, toAdd) {
368+
if !slices.Contains(toAdd, ip) {
375369
registryEntries = append(registryEntries, ip)
376370
continue
377371
}
378372
var err error
379373
if runtime.GOOS == "windows" {
380374
// Don't addAddress if this is already configured.
381-
if !utils.ContainsString(ip, configuredIPs) {
375+
if !slices.Contains(configuredIPs, ip) {
382376
err = addAddress(net.ParseIP(ip), net.IPv4Mask(255, 255, 255, 255), uint32(iface.Index))
383377
}
384378
} else {
@@ -394,7 +388,7 @@ func (a *addressMgr) Set(ctx context.Context) error {
394388
for _, ip := range toRm {
395389
var err error
396390
if runtime.GOOS == "windows" {
397-
if !utils.ContainsString(ip, configuredIPs) {
391+
if !slices.Contains(configuredIPs, ip) {
398392
continue
399393
}
400394
err = removeAddress(net.ParseIP(ip), uint32(iface.Index))

google_guest_agent/diagnostics.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import (
1919
"encoding/json"
2020
"reflect"
2121
"runtime"
22+
"slices"
2223
"sync/atomic"
2324

2425
"github.com/GoogleCloudPlatform/guest-agent/google_guest_agent/cfg"
@@ -94,7 +95,7 @@ func (d *diagnosticsMgr) Set(ctx context.Context) error {
9495
}
9596

9697
strEntry := newMetadata.Instance.Attributes.Diagnostics
97-
if utils.ContainsString(strEntry, diagnosticsEntries) {
98+
if slices.Contains(diagnosticsEntries, strEntry) {
9899
return nil
99100
}
100101
diagnosticsEntries = append(diagnosticsEntries, strEntry)

google_guest_agent/non_windows_accounts.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import (
2323
"os/exec"
2424
"path"
2525
"runtime"
26+
"slices"
2627
"sort"
2728
"strconv"
2829
"strings"
@@ -214,7 +215,7 @@ func getUserKeys(mdkeys []string) map[string][]string {
214215
}
215216

216217
if err != nil {
217-
if !utils.ContainsString(trimmedKey, badSSHKeys) {
218+
if !slices.Contains(badSSHKeys, trimmedKey) {
218219
logger.Errorf("%s: %s", err.Error(), trimmedKey)
219220
badSSHKeys = append(badSSHKeys, trimmedKey)
220221
}

google_guest_agent/windows_accounts.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ import (
2929
"math/big"
3030
"reflect"
3131
"runtime"
32+
"slices"
3233
"strconv"
3334
"strings"
3435

@@ -422,7 +423,7 @@ func compareAccounts(newKeys metadata.WindowsKeys, oldStrKeys []string) metadata
422423
for _, s := range oldStrKeys {
423424
var key metadata.WindowsKey
424425
if err := json.Unmarshal([]byte(s), &key); err != nil {
425-
if !utils.ContainsString(s, badReg) {
426+
if !slices.Contains(badReg, s) {
426427
logger.Errorf("Bad windows key from registry: %s", err)
427428
badReg = append(badReg, s)
428429
}

google_guest_agent/windows_accounts_test.go

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -112,10 +112,8 @@ func TestAccountsDisabled(t *testing.T) {
112112
}
113113
}
114114

115-
// rename this with leading disabled because this is a resource
116-
// intensive test. this test takes approx. 141 seconds to complete, next
117-
// longest test is 0.43 seconds.
118-
func disabledTestNewPwd(t *testing.T) {
115+
// Test takes ~43 sec to complete and is resource intensive.
116+
func TestNewPwd(t *testing.T) {
119117
minPasswordLength := 15
120118
maxPasswordLength := 255
121119
var tests = []struct {
@@ -133,31 +131,33 @@ func disabledTestNewPwd(t *testing.T) {
133131
}
134132

135133
for _, tt := range tests {
136-
for i := 0; i < 100000; i++ {
137-
pwd, err := newPwd(tt.passwordLength)
138-
if err != nil {
139-
t.Fatal(err)
140-
}
141-
if len(pwd) != tt.wantPasswordLength {
142-
t.Errorf("Password is not %d characters: len(%s)=%d", tt.wantPasswordLength, pwd, len(pwd))
143-
}
144-
var l, u, n, s int
145-
for _, r := range pwd {
146-
switch {
147-
case unicode.IsLower(r):
148-
l = 1
149-
case unicode.IsUpper(r):
150-
u = 1
151-
case unicode.IsDigit(r):
152-
n = 1
153-
case unicode.IsPunct(r) || unicode.IsSymbol(r):
154-
s = 1
134+
t.Run(tt.name, func(t *testing.T) {
135+
for i := 0; i < 100000; i++ {
136+
pwd, err := newPwd(tt.passwordLength)
137+
if err != nil {
138+
t.Fatal(err)
139+
}
140+
if len(pwd) != tt.wantPasswordLength {
141+
t.Errorf("Password is not %d characters: len(%s)=%d", tt.wantPasswordLength, pwd, len(pwd))
142+
}
143+
var l, u, n, s int
144+
for _, r := range pwd {
145+
switch {
146+
case unicode.IsLower(r):
147+
l = 1
148+
case unicode.IsUpper(r):
149+
u = 1
150+
case unicode.IsDigit(r):
151+
n = 1
152+
case unicode.IsPunct(r) || unicode.IsSymbol(r):
153+
s = 1
154+
}
155+
}
156+
if l+u+n+s < 3 {
157+
t.Errorf("Password does not have at least one character from 3 categories: '%v'", pwd)
155158
}
156159
}
157-
if l+u+n+s < 3 {
158-
t.Errorf("Password does not have at least one character from 3 categories: '%v'", pwd)
159-
}
160-
}
160+
})
161161
}
162162
}
163163

utils/file.go

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
// Copyright 2024 Google LLC
2+
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
7+
// https://www.apache.org/licenses/LICENSE-2.0
8+
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
// OS file util for Google Guest Agent and Google Authorized Keys.
16+
17+
package utils
18+
19+
import (
20+
"fmt"
21+
"io/fs"
22+
"os"
23+
"path/filepath"
24+
)
25+
26+
// SaferWriteFile writes to a temporary file and then replaces the expected output file.
27+
// This prevents other processes from reading partial content while the writer is still writing.
28+
func SaferWriteFile(content []byte, outputFile string, perm fs.FileMode) error {
29+
dir := filepath.Dir(outputFile)
30+
name := filepath.Base(outputFile)
31+
32+
if err := os.MkdirAll(dir, perm); err != nil {
33+
return fmt.Errorf("unable to create required directories %q: %w", dir, err)
34+
}
35+
36+
tmp, err := os.CreateTemp(dir, name+"*")
37+
if err != nil {
38+
return fmt.Errorf("unable to create temporary file under %q: %w", dir, err)
39+
}
40+
41+
if err := os.Chmod(tmp.Name(), perm); err != nil {
42+
return fmt.Errorf("unable to set permissions on temporary file %q: %w", dir, err)
43+
}
44+
45+
if err := tmp.Close(); err != nil {
46+
return fmt.Errorf("failed to close temporary file: %w", err)
47+
}
48+
49+
if err := WriteFile(content, tmp.Name(), perm); err != nil {
50+
return fmt.Errorf("unable to write to a temporary file %q: %w", tmp.Name(), err)
51+
}
52+
53+
return os.Rename(tmp.Name(), outputFile)
54+
}
55+
56+
// CopyFile copies content from src to dst and sets permissions.
57+
func CopyFile(src, dst string, perm fs.FileMode) error {
58+
b, err := os.ReadFile(src)
59+
if err != nil {
60+
return fmt.Errorf("failed to read %q: %w", src, err)
61+
}
62+
63+
if err := WriteFile(b, dst, perm); err != nil {
64+
return fmt.Errorf("failed to write %q: %w", dst, err)
65+
}
66+
67+
if err := os.Chmod(dst, perm); err != nil {
68+
return fmt.Errorf("unable to set permissions on destination file %q: %w", dst, err)
69+
}
70+
71+
return nil
72+
}
73+
74+
// WriteFile creates parent directories if required and writes content to the output file.
75+
func WriteFile(content []byte, outputFile string, perm fs.FileMode) error {
76+
if err := os.MkdirAll(filepath.Dir(outputFile), perm); err != nil {
77+
return fmt.Errorf("unable to create required directories for %q: %w", outputFile, err)
78+
}
79+
return os.WriteFile(outputFile, content, perm)
80+
}

0 commit comments

Comments
 (0)