Skip to content

Commit 3598600

Browse files
authored
feat(strm-1922): store and use zed tokens for all calls (#146)
1 parent c60d4b0 commit 3598600

File tree

7 files changed

+139
-78
lines changed

7 files changed

+139
-78
lines changed

Makefile

+2
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ dist/${target}: ${source_files} Makefile
2424
clean:
2525
rm -f dist/${target}
2626

27+
# Make sure the .env containing all `STRM_TEST_*` variables is present in the ./test directory
28+
# godotenv loads the .env file from that directory when running the tests
2729
test: dist/${target}
2830
go clean -testcache
2931
go test ./test -v

cmd/strm/main.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"strmprivacy/strm/pkg/bootstrap"
1313
"strmprivacy/strm/pkg/common"
1414
"strmprivacy/strm/pkg/context"
15+
"strmprivacy/strm/pkg/user_project"
1516
"strmprivacy/strm/pkg/util"
1617
)
1718

@@ -78,7 +79,7 @@ func rootCmdPreRun(cmd *cobra.Command, args []string) error {
7879
common.ApiAuthHost = util.GetStringAndErr(cmd.Flags(), auth.ApiAuthUrlFlag)
7980

8081
if auth.Auth.LoadLogin() == nil {
81-
bootstrap.SetupServiceClients(auth.Auth.GetToken())
82+
bootstrap.SetupServiceClients(auth.Auth.GetToken(), user_project.GetZedToken())
8283
splitCommand := strings.Split(cmd.CommandPath(), " ")
8384
if splitCommand[1] != "auth" && !(splitCommand[1] == "create" && splitCommand[2] == "project") {
8485
context.ResolveProject(cmd.Flags())

pkg/bootstrap/bootstrap.go

+67-2
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,16 @@
11
package bootstrap
22

33
import (
4+
"context"
45
"fmt"
56
log "github.com/sirupsen/logrus"
67
"github.com/spf13/cobra"
78
"github.com/spf13/pflag"
89
"github.com/spf13/viper"
10+
"google.golang.org/grpc"
11+
"google.golang.org/grpc/credentials"
12+
"google.golang.org/grpc/credentials/insecure"
13+
"google.golang.org/grpc/metadata"
914
"strings"
1015
"strmprivacy/strm/pkg/cmd"
1116
"strmprivacy/strm/pkg/common"
@@ -30,6 +35,12 @@ import (
3035
"strmprivacy/strm/pkg/entity/user"
3136
"strmprivacy/strm/pkg/logs"
3237
"strmprivacy/strm/pkg/monitor"
38+
"strmprivacy/strm/pkg/user_project"
39+
)
40+
41+
const (
42+
cliVersionHeader = "strm-cli-version"
43+
zedTokenHeader = "strm-zed-token"
3344
)
3445

3546
/*
@@ -62,8 +73,8 @@ func SetupVerbs(rootCmd *cobra.Command) {
6273
rootCmd.AddCommand(cmd.EvaluateCmd)
6374
}
6475

65-
func SetupServiceClients(accessToken *string) {
66-
clientConnection, ctx := common.SetupGrpc(common.ApiHost, accessToken)
76+
func SetupServiceClients(accessToken *string, zedToken *string) {
77+
clientConnection, ctx := SetupGrpc(common.ApiHost, accessToken, zedToken)
6778

6879
stream.SetupClient(clientConnection, ctx)
6980
kafka_exporter.SetupClient(clientConnection, ctx)
@@ -146,3 +157,57 @@ func bindFlags(cmd *cobra.Command, v *viper.Viper) {
146157
}
147158
})
148159
}
160+
161+
func SetupGrpc(host string, token *string, zedToken *string) (*grpc.ClientConn, context.Context) {
162+
163+
var err error
164+
var creds grpc.DialOption
165+
166+
if strings.Contains(host, ":50051") {
167+
creds = grpc.WithTransportCredentials(insecure.NewCredentials())
168+
} else {
169+
creds = grpc.WithTransportCredentials(credentials.NewClientTLSFromCert(nil, ""))
170+
}
171+
172+
clientConnection, err := grpc.Dial(host, creds, grpc.WithUnaryInterceptor(clientInterceptor))
173+
common.CliExit(err)
174+
175+
var mdMap = map[string]string{cliVersionHeader: common.Version}
176+
177+
if token != nil {
178+
mdMap["authorization"] = "Bearer " + *token
179+
}
180+
if zedToken != nil {
181+
mdMap[zedTokenHeader] = *zedToken
182+
}
183+
184+
return clientConnection, metadata.NewOutgoingContext(context.Background(), metadata.New(mdMap))
185+
}
186+
187+
func clientInterceptor(
188+
ctx context.Context,
189+
method string,
190+
req interface{},
191+
reply interface{},
192+
cc *grpc.ClientConn,
193+
invoker grpc.UnaryInvoker,
194+
opts ...grpc.CallOption,
195+
) error {
196+
zedToken := user_project.GetZedToken()
197+
198+
if zedToken != nil {
199+
ctx = metadata.AppendToOutgoingContext(ctx, zedTokenHeader, *zedToken)
200+
}
201+
202+
var header metadata.MD
203+
opts = append(opts, grpc.Header(&header))
204+
err := invoker(ctx, method, req, reply, cc, opts...)
205+
206+
zedTokenValue := header.Get(zedTokenHeader)
207+
208+
if len(zedTokenValue) > 0 {
209+
user_project.SetZedToken(zedTokenValue[0])
210+
}
211+
212+
return err
213+
}

pkg/common/common.go

-31
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,14 @@
11
package common
22

33
import (
4-
"context"
54
"errors"
65
"fmt"
76
log "github.com/sirupsen/logrus"
87
"github.com/spf13/cobra"
9-
"google.golang.org/grpc"
10-
"google.golang.org/grpc/credentials"
11-
"google.golang.org/grpc/credentials/insecure"
12-
"google.golang.org/grpc/metadata"
138
"google.golang.org/grpc/status"
149
"gopkg.in/natefinch/lumberjack.v2"
1510
"os"
1611
"runtime"
17-
"strings"
1812
)
1913

2014
var RootCommandName = "strm"
@@ -24,31 +18,6 @@ var ApiHost string
2418

2519
var ProjectId string
2620

27-
func SetupGrpc(host string, token *string) (*grpc.ClientConn, context.Context) {
28-
29-
var err error
30-
var creds grpc.DialOption
31-
32-
if strings.Contains(host, ":50051") {
33-
creds = grpc.WithTransportCredentials(insecure.NewCredentials())
34-
} else {
35-
creds = grpc.WithTransportCredentials(credentials.NewClientTLSFromCert(nil, ""))
36-
}
37-
38-
clientConnection, err := grpc.Dial(host, creds)
39-
CliExit(err)
40-
41-
var md metadata.MD
42-
if token != nil {
43-
md = metadata.New(map[string]string{"authorization": "Bearer " + *token, "strm-cli-version": Version})
44-
} else {
45-
md = metadata.New(map[string]string{"strm-cli-version": Version})
46-
}
47-
48-
ctx := metadata.NewOutgoingContext(context.Background(), md)
49-
return clientConnection, ctx
50-
}
51-
5221
func CliExit(err error) {
5322
if err != nil {
5423
_, file, line, _ := runtime.Caller(1)

pkg/context/project.go

+3-25
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,23 @@
11
package context
22

33
import (
4-
"encoding/json"
54
"errors"
65
"fmt"
76
log "github.com/sirupsen/logrus"
87
"github.com/spf13/pflag"
98
"os"
10-
"path"
119
"strmprivacy/strm/pkg/common"
1210
"strmprivacy/strm/pkg/entity/project"
1311
"strmprivacy/strm/pkg/user_project"
1412
)
1513

16-
const activeProjectFilename = "active_projects.json"
17-
1814
// ResolveProject resolves the project to use and makes its ID globally available.
1915
// The value passed through the flag takes precedence, then the value stored in the config dir, and finally
2016
// a fallback to default project.
2117
func ResolveProject(f *pflag.FlagSet) {
22-
23-
activeProjectFilePath := path.Join(common.ConfigPath(), activeProjectFilename)
2418
projectFlagValue, _ := f.GetString(common.ProjectNameFlag)
2519

26-
if _, err := os.Stat(activeProjectFilePath); os.IsNotExist(err) && projectFlagValue == "" {
20+
if _, err := os.Stat(user_project.ActiveProjectFilepath); os.IsNotExist(err) && projectFlagValue == "" {
2721
initActiveProject()
2822
fmt.Println(fmt.Sprintf("Active project was not yet set, has been set to '%v'. You can set a project "+
2923
"with 'strm context project <project-name>'\n", user_project.GetActiveProject()))
@@ -54,7 +48,7 @@ func ResolveProject(f *pflag.FlagSet) {
5448

5549
func SetActiveProject(projectName string) {
5650
if len(project.GetProject(projectName).Projects) != 0 {
57-
saveActiveProject(projectName)
51+
user_project.Projects.SetActiveProject(projectName)
5852
message := "Active project set to: " + projectName
5953
log.Infoln(message)
6054
fmt.Println(message)
@@ -75,21 +69,5 @@ func getFirstProject() string {
7569

7670
func initActiveProject() {
7771
firstProjectName := getFirstProject()
78-
saveActiveProject(firstProjectName)
79-
}
80-
81-
func saveActiveProject(projectName string) {
82-
activeProjectFilepath := path.Join(common.ConfigPath(), activeProjectFilename)
83-
user_project.Projects.SetActiveProject(projectName)
84-
projects, err := json.Marshal(user_project.Projects)
85-
if err != nil {
86-
common.CliExit(err)
87-
}
88-
89-
err = os.WriteFile(
90-
activeProjectFilepath,
91-
projects,
92-
0644,
93-
)
94-
common.CliExit(err)
72+
user_project.Projects.SetActiveProject(firstProjectName)
9573
}

pkg/user_project/user_projects.go

+56-17
Original file line numberDiff line numberDiff line change
@@ -9,22 +9,25 @@ import (
99
"strmprivacy/strm/pkg/common"
1010
)
1111

12-
const ActiveProjectFilename = "active_projects.json"
12+
const activeProjectFilename = "active_projects.json"
1313

14-
var Projects UsersProjects
14+
var ActiveProjectFilepath = path.Join(common.ConfigPath(), activeProjectFilename)
1515

16-
// UsersProjects is the printed json format of the different active projects
16+
var Projects *UsersProjectsContext
17+
18+
// UsersProjectsContext is the printed json format of the different active projects
1719
// per past or currently logged-in user
18-
type UsersProjects struct {
19-
Users []UserProject `json:"users"`
20+
type UsersProjectsContext struct {
21+
Users []UserProjectContext `json:"users"`
2022
}
2123

22-
type UserProject struct {
24+
type UserProjectContext struct {
2325
Email string `json:"email"`
2426
ActiveProject string `json:"active_project"`
27+
ZedToken string `json:"zed_token"`
2528
}
2629

27-
func (projects *UsersProjects) GetCurrentProjectByEmail() string {
30+
func (projects *UsersProjectsContext) GetCurrentProjectByEmail() string {
2831
activeProject := ""
2932
email := GetUserEmail()
3033
for _, user := range projects.Users {
@@ -35,7 +38,7 @@ func (projects *UsersProjects) GetCurrentProjectByEmail() string {
3538
return activeProject
3639
}
3740

38-
func (projects *UsersProjects) SetActiveProject(project string) {
41+
func (projects *UsersProjectsContext) SetActiveProject(project string) {
3942
email := GetUserEmail()
4043
added := false
4144
for index, user := range projects.Users {
@@ -46,11 +49,13 @@ func (projects *UsersProjects) SetActiveProject(project string) {
4649
}
4750

4851
if !added {
49-
projects.Users = append(projects.Users, UserProject{
52+
projects.Users = append(projects.Users, UserProjectContext{
5053
Email: email,
5154
ActiveProject: project,
5255
})
5356
}
57+
58+
storeUserProjectContext()
5459
}
5560

5661
func GetUserEmail() string {
@@ -63,19 +68,53 @@ func GetUserEmail() string {
6368
return auth.Auth.Email
6469
}
6570

66-
func LoadActiveProject() {
67-
activeProjectFilePath := path.Join(common.ConfigPath(), ActiveProjectFilename)
71+
func initializeUsersProjectsContext() {
72+
if Projects == nil {
73+
activeProjectFilePath := path.Join(common.ConfigPath(), activeProjectFilename)
6874

69-
bytes, err := os.ReadFile(activeProjectFilePath)
70-
common.CliExit(err)
71-
activeProjects := UsersProjects{}
72-
_ = json.Unmarshal(bytes, &activeProjects)
73-
Projects = activeProjects
75+
bytes, err := os.ReadFile(activeProjectFilePath)
76+
common.CliExit(err)
77+
activeProjects := UsersProjectsContext{}
78+
_ = json.Unmarshal(bytes, &activeProjects)
79+
Projects = &activeProjects
80+
}
7481
}
7582

7683
func GetActiveProject() string {
77-
LoadActiveProject()
84+
initializeUsersProjectsContext()
7885
activeProject := Projects.GetCurrentProjectByEmail()
7986
log.Infoln("Current active project is: " + activeProject)
8087
return activeProject
8188
}
89+
90+
func SetZedToken(zedToken string) {
91+
initializeUsersProjectsContext()
92+
email := GetUserEmail()
93+
for index, user := range Projects.Users {
94+
// If there is no entry for the user, a zed token will be added the next time, when it is present
95+
if user.Email == email {
96+
(*Projects).Users[index].ZedToken = zedToken
97+
}
98+
}
99+
100+
storeUserProjectContext()
101+
}
102+
103+
func GetZedToken() *string {
104+
initializeUsersProjectsContext()
105+
email := GetUserEmail()
106+
for _, user := range Projects.Users {
107+
if user.Email == email && user.ZedToken != "" {
108+
return &user.ZedToken
109+
}
110+
}
111+
return nil
112+
}
113+
114+
func storeUserProjectContext() {
115+
projects, err := json.Marshal(Projects)
116+
common.CliExit(err)
117+
118+
err = os.WriteFile(ActiveProjectFilepath, projects, 0644)
119+
common.CliExit(err)
120+
}

test/test_util.go

+9-2
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,13 @@ var _testConfig TestConfig
3131

3232
func testConfig() *TestConfig {
3333
if (TestConfig{}) == _testConfig {
34-
_ = godotenv.Load()
34+
err := godotenv.Load()
35+
36+
if err != nil && os.Getenv("GITHUB_ACTION") == "" {
37+
fmt.Fprintf(os.Stderr, "Error loading .env file: %v\n", err)
38+
os.Exit(1)
39+
}
40+
3541
_testConfig = TestConfig{
3642
projectId: os.Getenv("STRM_TEST_PROJECT_ID"),
3743
email: os.Getenv("STRM_TEST_USER_EMAIL"),
@@ -64,7 +70,7 @@ func newConfigDir() string {
6470
_ = os.Setenv("STRM_API_AUTH_URL", "https://accounts.dev.strmprivacy.io")
6571
_ = os.Setenv("STRM_API_HOST", "api.dev.strmprivacy.io:443")
6672
_ = os.Setenv("STRM_HEADLESS", "true")
67-
_ = os.WriteFile(configDir+"/active_project", []byte("default"), 0644)
73+
_ = os.WriteFile(configDir+"/active_projects.json", []byte(fmt.Sprintf(`{"users":[{"email":"%s","active_project":"default"}]}`, os.Getenv("STRM_TEST_USER_EMAIL"))), 0644)
6874
return configDir
6975
}
7076

@@ -212,6 +218,7 @@ func ExecuteAndVerify(t *testing.T, expected proto.Message, args ...string) {
212218
out, err := TryLoad(outputMessage, output)
213219
if err != nil {
214220
fmt.Println("Can't execute", args)
221+
fmt.Fprintln(os.Stderr, err)
215222
t.Fail()
216223
}
217224
assertProtoEquals(t, out, expected)

0 commit comments

Comments
 (0)