Skip to content

Commit 2ac180a

Browse files
Philmodpsbang
andauthored
Use a new KAGGLE_GRPC_DATA_PROXY_URL env variable for gRPC proxying (#1337)
http://b/308644984 --------- Co-authored-by: Prathamesh Bang <[email protected]>
1 parent 2da7966 commit 2ac180a

File tree

3 files changed

+17
-9
lines changed

3 files changed

+17
-9
lines changed

Diff for: patches/sitecustomize.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,9 @@ def post_import_logic(module):
8181
if os.getenv('KAGGLE_DISABLE_GOOGLE_GENERATIVE_AI_INTEGRATION') != None:
8282
return
8383
if (os.getenv('KAGGLE_DATA_PROXY_TOKEN') == None or
84-
os.getenv('KAGGLE_USER_SECRETS_TOKEN') == None or
85-
os.getenv('KAGGLE_DATA_PROXY_URL') == None):
84+
os.getenv('KAGGLE_USER_SECRETS_TOKEN') == None or
85+
(os.getenv('KAGGLE_DATA_PROXY_URL') == None and
86+
os.getenv('KAGGLE_GRPC_DATA_PROXY_URL') == None)):
8687
return
8788

8889
old_configure = module.configure
@@ -101,12 +102,15 @@ def new_configure(*args, **kwargs):
101102
client_options = kwargs['client_options']
102103
else:
103104
client_options = {}
104-
client_options['api_endpoint'] = os.environ['KAGGLE_DATA_PROXY_URL']
105+
105106
if os.getenv('KAGGLE_GOOGLE_GENERATIVE_AI_USE_REST_ONLY') != None:
106-
client_options['api_endpoint'] += '/palmapi'
107107
kwargs['transport'] = 'rest'
108-
elif 'transport' in kwargs and kwargs['transport'] == 'rest':
108+
109+
if 'transport' in kwargs and kwargs['transport'] == 'rest':
110+
client_options['api_endpoint'] = os.environ['KAGGLE_DATA_PROXY_URL']
109111
client_options['api_endpoint'] += '/palmapi'
112+
else:
113+
client_options['api_endpoint'] = os.environ['KAGGLE_GRPC_DATA_PROXY_URL']
110114
kwargs['client_options'] = client_options
111115

112116
old_configure(*args, **kwargs)

Diff for: tests/test_google_generativeai_patch.py

+1
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ def test_proxy_enabled(self):
3333
env.set("KAGGLE_USER_SECRETS_TOKEN", secrets_token)
3434
env.set("KAGGLE_DATA_PROXY_TOKEN", proxy_token)
3535
env.set("KAGGLE_DATA_PROXY_URL", self.endpoint)
36+
env.set("KAGGLE_GRPC_DATA_PROXY_URL", "http://127.0.0.1:50001")
3637
env.set("KAGGLE_GOOGLE_GENERATIVE_AI_USE_REST_ONLY", "True")
3738
server_address = urlparse(self.endpoint)
3839
with env:

Diff for: tests/test_google_generativeai_patch_disabled.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -14,21 +14,24 @@ def do_HEAD(self):
1414
self.send_response(200)
1515

1616
def do_GET(self):
17+
print('YO MOD', self.path)
1718
HTTPHandler.called = True
1819
self.send_response(200)
1920
self.send_header("Content-type", "application/json")
2021
self.end_headers()
2122

2223
class TestGoogleGenerativeAiPatchDisabled(unittest.TestCase):
23-
endpoint = "http://127.0.0.1:80"
24+
http_endpoint = "http://127.0.0.1:80"
25+
grpc_endpoint = "http://127.0.0.1:50001"
2426

2527
def test_disabled(self):
2628
env = EnvironmentVarGuard()
2729
env.set("KAGGLE_USER_SECRETS_TOKEN", "foobar")
2830
env.set("KAGGLE_DATA_PROXY_TOKEN", "foobar")
29-
env.set("KAGGLE_DATA_PROXY_URL", self.endpoint)
31+
env.set("KAGGLE_DATA_PROXY_URL", self.http_endpoint)
32+
env.set("KAGGLE_GRPC_DATA_PROXY_URL", self.grpc_endpoint)
3033
env.set("KAGGLE_DISABLE_GOOGLE_GENERATIVE_AI_INTEGRATION", "True")
31-
server_address = urlparse(self.endpoint)
34+
server_address = urlparse(self.http_endpoint)
3235
with env:
3336
with HTTPServer((server_address.hostname, server_address.port), HTTPHandler) as httpd:
3437
threading.Thread(target=httpd.serve_forever).start()
@@ -40,4 +43,4 @@ def test_disabled(self):
4043
except:
4144
pass
4245
httpd.shutdown()
43-
self.assertFalse(HTTPHandler.called)
46+
self.assertFalse(HTTPHandler.called)

0 commit comments

Comments
 (0)