Skip to content

Commit d436ba5

Browse files
committed
update vcom test to use the protobuf based API
Signed-off-by: Shahriyar Jalayeri <[email protected]>
1 parent 19f4670 commit d436ba5

File tree

2 files changed

+247
-65
lines changed

2 files changed

+247
-65
lines changed

tests/vcom/testscripts.go

Lines changed: 215 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,220 @@
11
package vcom
22

33
const testScript = `#!/usr/bin/env python3
4-
5-
"""Test vsock communication with the host"""
64
import socket
5+
import sys
6+
import os
7+
8+
# Add the path to the generated protobuf files
9+
sys.path.append('.')
10+
import messages_pb2
11+
12+
VMADDR_CID_LOCAL = socket.VMADDR_CID_HOST # change to 1 for unix.VMADDR_CID_LOCAL
13+
TPM_EK_HANDLE = 0x81000001 # etpm.TpmEKHdl value
14+
VSOCK_PORT = 2000
15+
16+
def vsock_http_request(cid, port, method, path, body=None, headers=None):
17+
"""Make HTTP request over VSOCK"""
18+
if headers is None:
19+
headers = {}
20+
21+
sock = socket.socket(socket.AF_VSOCK, socket.SOCK_STREAM)
22+
try:
23+
sock.connect((cid, port))
24+
request_lines = [f"{method} {path} HTTP/1.1"]
25+
request_lines.append("Host: vsock")
26+
request_lines.append("Connection: close")
27+
for key, value in headers.items():
28+
request_lines.append(f"{key}: {value}")
29+
if body:
30+
request_lines.append(f"Content-Length: {len(body)}")
31+
request_lines.append("")
32+
http_request = "\r\n".join(request_lines) + "\r\n"
33+
34+
# Send request
35+
sock.send(http_request.encode('utf-8'))
36+
if body:
37+
sock.send(body)
38+
39+
# Read response - first read headers
40+
response_data = b""
41+
headers_end = b"\r\n\r\n"
42+
while headers_end not in response_data:
43+
chunk = sock.recv(1024)
44+
if not chunk:
45+
break
46+
response_data += chunk
47+
headers_part = response_data.split(headers_end)[0].decode('utf-8')
48+
content_length = 0
49+
for line in headers_part.split('\r\n'):
50+
if line.lower().startswith('content-length:'):
51+
content_length = int(line.split(':', 1)[1].strip())
52+
break
53+
body_start = response_data.find(headers_end) + len(headers_end)
54+
body_received = len(response_data) - body_start
55+
56+
while body_received < content_length:
57+
chunk = sock.recv(min(4096, content_length - body_received))
58+
if not chunk:
59+
break
60+
response_data += chunk
61+
body_received += len(chunk)
62+
63+
# Parse HTTP response
64+
response_str = response_data.decode('utf-8', errors='ignore')
65+
if '\r\n\r\n' in response_str:
66+
headers_part, body_part = response_str.split('\r\n\r\n', 1)
67+
else:
68+
headers_part = response_str
69+
body_part = ""
70+
71+
# Parse status line
72+
status_line = headers_part.split('\r\n')[0]
73+
status_code = int(status_line.split()[1])
74+
75+
# Return binary body for protobuf
76+
body_start = response_data.find(b'\r\n\r\n')
77+
if body_start != -1:
78+
binary_body = response_data[body_start + 4:]
79+
else:
80+
binary_body = b""
81+
82+
return status_code, binary_body
83+
84+
finally:
85+
sock.close()
86+
87+
def decode_key_attr(attr):
88+
"""Decode key attributes"""
89+
flags = {
90+
"FlagFixedTPM": 0x00000002,
91+
"FlagStClear": 0x00000004,
92+
"FlagFixedParent": 0x00000010,
93+
"FlagSensitiveDataOrigin": 0x00000020,
94+
"FlagUserWithAuth": 0x00000040,
95+
"FlagAdminWithPolicy": 0x00000080,
96+
"FlagNoDA": 0x00000400,
97+
"FlagRestricted": 0x00010000,
98+
"FlagDecrypt": 0x00020000,
99+
"FlagSign": 0x00040000,
100+
}
101+
102+
attr_list = []
103+
for name, value in flags.items():
104+
if attr & value != 0:
105+
attr_list.append(name)
106+
107+
if not attr_list:
108+
return "NO ATTRIBUTES"
109+
110+
return " | ".join(attr_list)
111+
112+
def test_valid_get_public():
113+
"""Test getting public key from TPM via VSOCK HTTP"""
114+
try:
115+
request = messages_pb2.TpmRequestGetPub()
116+
request.index = TPM_EK_HANDLE
117+
serialized_request = request.SerializeToString()
118+
119+
print(f"Sending TPM GetPub request via VSOCK (CID: {VMADDR_CID_LOCAL}, Port: {VSOCK_PORT})...")
120+
status_code, response_body = vsock_http_request(
121+
cid=VMADDR_CID_LOCAL,
122+
port=VSOCK_PORT,
123+
method="POST",
124+
path="/tpm/getpub",
125+
body=serialized_request
126+
)
127+
128+
# Check status
129+
if status_code != 200:
130+
print(f"Error: expected status 200, got {status_code}")
131+
return False
132+
133+
# Parse protobuf response
134+
tmp_resp = messages_pb2.TpmResponseGetPub()
135+
tmp_resp.ParseFromString(response_body)
136+
137+
# Validate response
138+
if len(tmp_resp.public) == 0:
139+
print("Error: expected non-empty EK, got empty")
140+
return False
141+
142+
# Print results like Go version
143+
print(f"TPM EK: {tmp_resp.public[:16].hex()}...")
144+
print(f"TPM EK Algorithm: {tmp_resp.algorithm}")
145+
print(f"TPM EK Attributes: {decode_key_attr(tmp_resp.attributes)}")
146+
return True
147+
148+
except Exception as e:
149+
print(f"Error occurred: {e}")
150+
import traceback
151+
traceback.print_exc()
152+
return False
153+
154+
def main():
155+
"""Main function"""
156+
print("Testing TPM Get Public Key via VSOCK HTTP...")
157+
success = test_valid_get_public()
158+
159+
if success:
160+
print("\nTest passed!")
161+
sys.exit(0)
162+
else:
163+
print("\nTest failed!")
164+
sys.exit(1)
165+
166+
if __name__ == '__main__':
167+
main()`
168+
169+
const protobufFile = `# Copyright (c) 2025 Zededa, Inc.
170+
# SPDX-License-Identifier: Apache-2.0
171+
# -*- coding: utf-8 -*-
172+
# Generated by the protocol buffer compiler. DO NOT EDIT!
173+
# source: proto/messages.proto
174+
"""Generated protocol buffer code."""
175+
from google.protobuf import descriptor as _descriptor
176+
from google.protobuf import descriptor_pool as _descriptor_pool
177+
from google.protobuf import symbol_database as _symbol_database
178+
from google.protobuf.internal import builder as _builder
179+
# @@protoc_insertion_point(imports)
180+
181+
_sym_db = _symbol_database.Default()
182+
183+
184+
185+
186+
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x14proto/messages.proto\x12\x04vcom\"!\n\x10TpmRequestGetPub\x12\r\n\x05index\x18\x01 \x01(\r\"J\n\x11TpmResponseGetPub\x12\x0e\n\x06public\x18\x01 \x01(\x0c\x12\x11\n\talgorithm\x18\x02 \x01(\r\x12\x12\n\nattributes\x18\x03 \x01(\r\"-\n\x0eTpmRequestSign\x12\r\n\x05index\x18\x01 \x01(\r\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"\x91\x01\n\x0fTpmResponseSign\x12\x11\n\talgorithm\x18\x01 \x01(\t\x12\x15\n\rrsa_signature\x18\x02 \x01(\x0c\x12\x10\n\x08rsa_hash\x18\x03 \x01(\t\x12\x17\n\x0f\x65\x63\x63_signature_r\x18\x04 \x01(\x0c\x12\x17\n\x0f\x65\x63\x63_signature_s\x18\x05 \x01(\x0c\x12\x10\n\x08\x65\x63\x63_hash\x18\x06 \x01(\t\"!\n\x10TpmRequestReadNv\x12\r\n\x05index\x18\x01 \x01(\r\"!\n\x11TpmResponseReadNv\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\"-\n\x1cTpmRequestActivateCredParams\x12\r\n\x05index\x18\x01 \x01(\r\"N\n\x1dTpmResponseActivateCredParams\x12\n\n\x02\x65k\x18\x01 \x01(\x0c\x12\x0f\n\x07\x61ik_pub\x18\x02 \x01(\x0c\x12\x10\n\x08\x61ik_name\x18\x03 \x01(\x0c\"J\n\x17TpmRequestGeneratedCred\x12\x0c\n\x04\x63red\x18\x01 \x01(\x0c\x12\x0e\n\x06secret\x18\x02 \x01(\x0c\x12\x11\n\taik_index\x18\x03 \x01(\r\"*\n\x18TpmResponseActivatedCred\x12\x0e\n\x06secret\x18\x01 \x01(\x0c\"\"\n\x11TpmRequestCertify\x12\r\n\x05index\x18\x01 \x01(\r\"A\n\x12TpmResponseCertify\x12\x0e\n\x06public\x18\x01 \x01(\x0c\x12\x0b\n\x03sig\x18\x02 \x01(\x0c\x12\x0e\n\x06\x61ttest\x18\x03 \x01(\x0c\x42\x07Z\x05vcom/b\x06proto3')
187+
188+
_globals = globals()
189+
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
190+
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'proto.messages_pb2', _globals)
191+
if _descriptor._USE_C_DESCRIPTORS == False:
7192
8-
CID = socket.VMADDR_CID_HOST
9-
PORT = 2000
10-
s = socket.socket(socket.AF_VSOCK, socket.SOCK_STREAM)
11-
s.connect((CID, PORT))
12-
s.sendall(b"{\"channel\":2,\"request\":1}")
13-
response = s.recv(1024)
14-
print(response.decode('utf-8'))
15-
s.close()`
193+
DESCRIPTOR._options = None
194+
DESCRIPTOR._serialized_options = b'Z\005vcom/'
195+
_globals['_TPMREQUESTGETPUB']._serialized_start=30
196+
_globals['_TPMREQUESTGETPUB']._serialized_end=63
197+
_globals['_TPMRESPONSEGETPUB']._serialized_start=65
198+
_globals['_TPMRESPONSEGETPUB']._serialized_end=139
199+
_globals['_TPMREQUESTSIGN']._serialized_start=141
200+
_globals['_TPMREQUESTSIGN']._serialized_end=186
201+
_globals['_TPMRESPONSESIGN']._serialized_start=189
202+
_globals['_TPMRESPONSESIGN']._serialized_end=334
203+
_globals['_TPMREQUESTREADNV']._serialized_start=336
204+
_globals['_TPMREQUESTREADNV']._serialized_end=369
205+
_globals['_TPMRESPONSEREADNV']._serialized_start=371
206+
_globals['_TPMRESPONSEREADNV']._serialized_end=404
207+
_globals['_TPMREQUESTACTIVATECREDPARAMS']._serialized_start=406
208+
_globals['_TPMREQUESTACTIVATECREDPARAMS']._serialized_end=451
209+
_globals['_TPMRESPONSEACTIVATECREDPARAMS']._serialized_start=453
210+
_globals['_TPMRESPONSEACTIVATECREDPARAMS']._serialized_end=531
211+
_globals['_TPMREQUESTGENERATEDCRED']._serialized_start=533
212+
_globals['_TPMREQUESTGENERATEDCRED']._serialized_end=607
213+
_globals['_TPMRESPONSEACTIVATEDCRED']._serialized_start=609
214+
_globals['_TPMRESPONSEACTIVATEDCRED']._serialized_end=651
215+
_globals['_TPMREQUESTCERTIFY']._serialized_start=653
216+
_globals['_TPMREQUESTCERTIFY']._serialized_end=687
217+
_globals['_TPMRESPONSECERTIFY']._serialized_start=689
218+
_globals['_TPMRESPONSECERTIFY']._serialized_end=754
219+
# @@protoc_insertion_point(module_scope)
220+
`

tests/vcom/vcom_test.go

Lines changed: 32 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package vcom
22

33
import (
4-
"encoding/json"
54
"fmt"
65
"os"
76
"path"
@@ -11,7 +10,6 @@ import (
1110

1211
tk "github.com/lf-edge/eden/pkg/evetestkit"
1312
"github.com/lf-edge/eden/pkg/utils"
14-
"github.com/lf-edge/eve/pkg/pillar/vcom"
1513
)
1614

1715
var eveNode *tk.EveNode
@@ -44,36 +42,6 @@ func logInfof(format string, args ...interface{}) {
4442
}
4543
}
4644

47-
func getChannel(data []byte) (uint, error) {
48-
var msg vcom.Base
49-
err := json.Unmarshal(data, &msg)
50-
if err != nil {
51-
return 0, err
52-
}
53-
54-
return uint(msg.Channel), nil
55-
}
56-
57-
func decodeTpmResponseEK(data []byte) (*vcom.TpmResponseEk, error) {
58-
tpmRes := new(vcom.TpmResponseEk)
59-
err := json.Unmarshal(data, tpmRes)
60-
if err != nil {
61-
return nil, err
62-
}
63-
64-
return tpmRes, nil
65-
}
66-
67-
func decodeError(data []byte) (*vcom.Error, error) {
68-
errMsg := new(vcom.Error)
69-
err := json.Unmarshal(data, errMsg)
70-
if err != nil {
71-
return nil, err
72-
}
73-
74-
return errMsg, nil
75-
}
76-
7745
func dumpScript(name, content string) error {
7846
return os.WriteFile(name, []byte(content), 0644)
7947
}
@@ -148,7 +116,25 @@ func TestVcomLinkTpmRequestEK(t *testing.T) {
148116
}
149117
logInfof("SSH connection with VM established.")
150118

119+
// make sure python3-protobuf is installed
120+
logInfof("Installing python3-protobuf...")
121+
_, err = eveNode.AppSSHExec(appName, "sudo apt-get update && sudo apt-get install -y python3-protobuf")
122+
if err != nil {
123+
logFatalf("Failed to install python3-protobuf: %v", err)
124+
}
125+
151126
logInfof("Copying test scripts to the vm...")
127+
// dump the protobuf file
128+
err = dumpScript("messages_pb2.py", protobufFile)
129+
if err != nil {
130+
logFatalf("Failed to get path to messages_pb2.py: %v", err)
131+
}
132+
err = eveNode.AppSCPCopy(appName, "messages_pb2.py", "messages_pb2.py")
133+
if err != nil {
134+
logFatalf("Failed to copy messages_pb2.py to the vm: %v", err)
135+
}
136+
137+
// dump the test script
152138
err = dumpScript("testvsock.py", testScript)
153139
if err != nil {
154140
logFatalf("Failed to get path to testvsock.py: %v", err)
@@ -157,36 +143,27 @@ func TestVcomLinkTpmRequestEK(t *testing.T) {
157143
if err != nil {
158144
logFatalf("Failed to copy testvsock.py to the vm: %v", err)
159145
}
146+
147+
// run the test script
148+
logInfof("Testing TPM Get Public Key via vComLink...")
160149
out, err := eveNode.AppSSHExec(appName, "python3 testvsock.py")
161150
if err != nil {
162151
logFatalf("Failed to communicate with host via vsock: %v", err)
163152
}
164153

154+
// check the response
165155
logInfof("Processing vComLink<->VM response...")
166-
channel, err := getChannel([]byte(out))
167-
if err != nil {
168-
logFatalf("Failed to get channel from the output: %v", err)
169-
}
170-
if channel == uint(vcom.ChannelError) {
171-
errMsg, err := decodeError([]byte(out))
172-
if err != nil {
173-
logFatalf("Failed to decode error message: %v", err)
174-
}
175-
logFatalf("Received error message instead of EK: %s", errMsg.Error)
176-
}
177-
if channel != uint(vcom.ChannelTpm) {
178-
logFatalf("Expected channel %d, got %d", vcom.ChannelTpm, channel)
179-
}
180-
181-
logInfof("Received expected TPM response from in the vm")
182-
tpmRes, err := decodeTpmResponseEK([]byte(out))
183-
if err != nil {
184-
logFatalf("Failed to decode tpm response: %v", err)
185-
}
186-
if tpmRes.Ek == "" {
187-
logFatalf("Received an empty EK from the vm")
156+
logInfof("Output: %s", out)
157+
// The script should return something like this, so lets just check for test passed
158+
// Testing TPM Get Public Key via VSOCK HTTP...
159+
// Sending TPM GetPub request via VSOCK (CID: 1, Port: 2000)...
160+
// TPM EK: 0001000b000300b20020837197674484...
161+
// TPM EK Algorithm: 1
162+
// TPM EK Attributes: FlagFixedTPM | FlagFixedParent | FlagSensitiveDataOrigin | FlagAdminWithPolicy | FlagRestricted | FlagDecrypt
163+
// Test passed!
164+
if !strings.Contains(string(out), "passed") {
165+
logFatalf("vComLink<->VM communication failed, output: %s", out)
188166
}
189-
logInfof("Received expected EK in the TPM response")
190167

191168
logInfof("TestvComLinkTpmRequestEK passed")
192169
}

0 commit comments

Comments
 (0)