Skip to content

Commit dc0b2a1

Browse files
Sunandhita BSunandhita B
Sunandhita B
authored and
Sunandhita B
committed
[Ananya/Sunandhita] Add. Implementation for GCP connector
1 parent 65fb840 commit dc0b2a1

File tree

6 files changed

+296
-0
lines changed

6 files changed

+296
-0
lines changed

jb-lib/lib/file_storage/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from .azure import AzureAsyncStorage, AzureSyncStorage
22
from .local import LocalAsyncStorage, LocalSyncStorage
3+
from .gcp import GcpAsyncStorage, GcpSyncStorage
34
from .storage import Storage
45
from .handler import StorageHandler
+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from .gcp_storage import GcpAsyncStorage
2+
from .gcp_sync_storage import GcpSyncStorage
+101
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
import os
2+
from typing import Union, Optional
3+
from datetime import datetime, timedelta, timezone
4+
import logging
5+
from google.cloud import storage
6+
import aiofiles
7+
8+
# Logger for this module, obtained under the fixed name "storage".
logger = logging.getLogger("storage")
9+
10+
class GcpAsyncStorage:
    """Asynchronous file storage backed by a Google Cloud Storage bucket.

    Blob operations from ``google.cloud.storage`` are handed to a worker
    thread with ``asyncio.to_thread`` so the coroutines do not execute them
    on the event-loop thread.

    Configuration is read from the environment when an instance is created:

    * ``GCP_STORAGE_PROJECT`` -- project id; falls back to
      ``DEFAULT_PROJECT_ID`` when unset or empty.
    * ``GCP_STORAGE_BUCKET_NAME`` -- bucket name; falls back to
      ``DEFAULT_BUCKET_NAME`` when unset or empty.

    Credentials are left to the client library's own lookup (Application
    Default Credentials, e.g. ``GOOGLE_APPLICATION_CREDENTIALS``). This class
    deliberately never writes that variable: pointing it at a path that only
    exists on one machine would break every other deployment and clobber a
    correctly configured environment.
    """

    # NOTE(review): STORAGE_REGISTRY annotates its values as
    # Type[AsyncStorage], but this class declares no base class -- confirm
    # whether it should inherit from AsyncStorage.

    __client__ = None
    tmp_folder = "/tmp/jb_files"

    # Fallbacks used when the corresponding environment variable is absent.
    DEFAULT_PROJECT_ID = 'indian-legal-bert'
    DEFAULT_BUCKET_NAME = 'jugalbandi'

    def __init__(self):
        """Create the storage client and make sure the temp folder exists.

        Raises:
            ValueError: if neither the environment nor the class defaults
                yield both a project id and a bucket name.
        """
        logger.info("Initializing GCP Storage")

        project_id, self.__bucket_name__ = self._read_config()
        if not project_id or not self.__bucket_name__:
            raise ValueError(
                "GCPAsyncStorage client not initialized. Missing project_id or bucket_name"
            )

        self.__client__ = storage.Client(project=project_id)
        os.makedirs(self.tmp_folder, exist_ok=True)

    @classmethod
    def _read_config(cls):
        """Return ``(project_id, bucket_name)``, preferring the environment."""
        # Single-argument getenv plus ``or`` so that an empty variable also
        # falls back to the class default.
        project_id = os.getenv("GCP_STORAGE_PROJECT") or cls.DEFAULT_PROJECT_ID
        bucket_name = os.getenv("GCP_STORAGE_BUCKET_NAME") or cls.DEFAULT_BUCKET_NAME
        return project_id, bucket_name

    def _blob_for(self, file_path):
        """Return the blob handle for ``file_path`` in the configured bucket.

        Raises:
            Exception: if the storage client has not been created.
        """
        if not self.__client__:
            raise Exception("GCPAsyncStorage client not initialized")

        bucket = self.__client__.bucket(self.__bucket_name__)
        # os.fspath leaves str untouched and unwraps os.PathLike objects.
        return bucket.blob(os.fspath(file_path))

    async def write_file(
        self,
        file_path: str,
        file_content: Union[str, bytes],
        mime_type: Optional[str] = None,
    ):
        """Upload ``file_content`` to the bucket under the name ``file_path``.

        Args:
            file_path: Blob name inside the bucket.
            file_content: Text or bytes to upload.
            mime_type: Content type to store. When omitted, ``audio/mpeg`` is
                used for names ending in ``.mp3`` (case-insensitive) and
                ``application/octet-stream`` for everything else.

        Raises:
            Exception: if the storage client has not been created.
        """
        # Function-scope import: the module's top-level imports do not
        # include asyncio.
        import asyncio

        blob = self._blob_for(file_path)

        if mime_type is None:
            mime_type = (
                "audio/mpeg" if file_path.lower().endswith(".mp3") else "application/octet-stream"
            )

        await asyncio.to_thread(blob.upload_from_string, file_content, content_type=mime_type)

    async def _download_file_to_temp_storage(
        self, file_path: Union[str, os.PathLike]
    ) -> Union[str, os.PathLike]:
        """Download the blob ``file_path`` into ``tmp_folder``.

        Args:
            file_path: Blob name; also used as the path below ``tmp_folder``.

        Returns:
            The local path the blob was written to.

        Raises:
            Exception: if the storage client has not been created.
        """
        # Function-scope import: the module's top-level imports do not
        # include asyncio.
        import asyncio

        blob = self._blob_for(file_path)

        tmp_file_path = os.path.join(self.tmp_folder, file_path)
        os.makedirs(os.path.dirname(tmp_file_path), exist_ok=True)

        def _download() -> None:
            # download_to_file calls write() synchronously, so it needs a
            # regular file object; an aiofiles handle would only hand back
            # un-awaited coroutines and leave the file empty.
            with open(tmp_file_path, "wb") as local_file:
                blob.download_to_file(local_file)

        await asyncio.to_thread(_download)
        return tmp_file_path

    def public_url(self, file_path: str) -> str:
        """Return a V4 signed GET URL for ``file_path`` that expires in one day.

        Raises:
            Exception: if the storage client has not been created.
        """
        blob = self._blob_for(file_path)
        return blob.generate_signed_url(
            version="v4",
            expiration=timedelta(days=1),
            method="GET",
        )
87+
88+
# Example usage
89+
async def main():
    """Round-trip a sample object through the bucket and print the results."""
    sample_name = 'example.txt'
    gcp_storage = GcpAsyncStorage()

    await gcp_storage.write_file(sample_name, 'Hello, World!')

    local_copy = await gcp_storage._download_file_to_temp_storage(sample_name)
    print(f"File downloaded to: {local_copy}")

    signed_url = gcp_storage.public_url(sample_name)
    print(f"Public URL: {signed_url}")


if __name__ == "__main__":
    # Script entry point: load variables from a .env file, then run the demo.
    import asyncio
    from dotenv import load_dotenv

    load_dotenv()
    asyncio.run(main())
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
import os
2+
from typing import Union, Optional
3+
from datetime import timedelta
4+
import logging
5+
from google.cloud import storage
6+
7+
# Logger for this module, obtained under the fixed name "storage".
logger = logging.getLogger("storage")
8+
9+
class GcpSyncStorage:
    """Blocking file storage backed by a Google Cloud Storage bucket.

    Configuration is read from the environment when an instance is created:

    * ``GCP_STORAGE_PROJECT`` -- project id; falls back to
      ``DEFAULT_PROJECT_ID`` when unset or empty.
    * ``GCP_STORAGE_BUCKET_NAME`` -- bucket name; falls back to
      ``DEFAULT_BUCKET_NAME`` when unset or empty.

    Credentials are left to the client library's own lookup (Application
    Default Credentials, e.g. ``GOOGLE_APPLICATION_CREDENTIALS``). This class
    deliberately never writes that variable: pointing it at a path that only
    exists on one machine would break every other deployment and clobber a
    correctly configured environment.
    """

    # NOTE(review): SYNC_STORAGE_REGISTRY annotates its values as
    # Type[SyncStorage], but this class declares no base class -- confirm
    # whether it should inherit from SyncStorage.

    __client__ = None
    tmp_folder = "/tmp/jb_files"

    # Fallbacks used when the corresponding environment variable is absent.
    DEFAULT_PROJECT_ID = 'indian-legal-bert'
    DEFAULT_BUCKET_NAME = 'jugalbandi'

    def __init__(self):
        """Create the storage client and make sure the temp folder exists.

        Raises:
            ValueError: if neither the environment nor the class defaults
                yield both a project id and a bucket name.
        """
        logger.info("Initializing GCP Storage")

        project_id, self.__bucket_name__ = self._read_config()
        if not project_id or not self.__bucket_name__:
            raise ValueError(
                "GCPStorage client not initialized. Missing project_id or bucket_name"
            )

        self.__client__ = storage.Client(project=project_id)
        os.makedirs(self.tmp_folder, exist_ok=True)

    @classmethod
    def _read_config(cls):
        """Return ``(project_id, bucket_name)``, preferring the environment."""
        # Single-argument getenv plus ``or`` so that an empty variable also
        # falls back to the class default.
        project_id = os.getenv("GCP_STORAGE_PROJECT") or cls.DEFAULT_PROJECT_ID
        bucket_name = os.getenv("GCP_STORAGE_BUCKET_NAME") or cls.DEFAULT_BUCKET_NAME
        return project_id, bucket_name

    def _blob_for(self, file_path):
        """Return the blob handle for ``file_path`` in the configured bucket.

        Raises:
            Exception: if the storage client has not been created.
        """
        if not self.__client__:
            raise Exception("GCPStorage client not initialized")

        bucket = self.__client__.bucket(self.__bucket_name__)
        # os.fspath leaves str untouched and unwraps os.PathLike objects.
        return bucket.blob(os.fspath(file_path))

    def write_file(
        self,
        file_path: str,
        file_content: Union[str, bytes],
        mime_type: Optional[str] = None,
    ):
        """Upload ``file_content`` to the bucket under the name ``file_path``.

        Args:
            file_path: Blob name inside the bucket.
            file_content: Text or bytes to upload.
            mime_type: Content type to store. When omitted, ``audio/mpeg`` is
                used for names ending in ``.mp3`` (case-insensitive) and
                ``application/octet-stream`` for everything else.

        Raises:
            Exception: if the storage client has not been created.
        """
        blob = self._blob_for(file_path)

        if mime_type is None:
            mime_type = (
                "audio/mpeg" if file_path.lower().endswith(".mp3") else "application/octet-stream"
            )

        blob.upload_from_string(file_content, content_type=mime_type)

    def download_file_to_temp_storage(
        self, file_path: Union[str, os.PathLike]
    ) -> Union[str, os.PathLike]:
        """Download the blob ``file_path`` into ``tmp_folder``.

        Args:
            file_path: Blob name; also used as the path below ``tmp_folder``.

        Returns:
            The local path the blob was written to.

        Raises:
            Exception: if the storage client has not been created.
        """
        blob = self._blob_for(file_path)

        tmp_file_path = os.path.join(self.tmp_folder, file_path)
        os.makedirs(os.path.dirname(tmp_file_path), exist_ok=True)

        with open(tmp_file_path, "wb") as local_file:
            blob.download_to_file(local_file)

        return tmp_file_path

    def public_url(self, file_path: str) -> str:
        """Return a V4 signed GET URL for ``file_path`` that expires in one day.

        Raises:
            Exception: if the storage client has not been created.
        """
        blob = self._blob_for(file_path)
        return blob.generate_signed_url(
            version="v4",
            expiration=timedelta(days=1),
            method="GET",
        )
81+
82+
# Example usage
83+
def main():
    """Round-trip a sample object through the bucket and print the results."""
    sample_name = 'example.txt'
    gcp_storage = GcpSyncStorage()

    gcp_storage.write_file(sample_name, 'Hello, World!')

    local_copy = gcp_storage.download_file_to_temp_storage(sample_name)
    print(f"File downloaded to: {local_copy}")

    signed_url = gcp_storage.public_url(sample_name)
    print(f"Public URL: {signed_url}")


if __name__ == "__main__":
    # Script entry point: run the demo when executed directly.
    main()

jb-lib/lib/file_storage/registry.py

+3
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,16 @@
22
from .storage import SyncStorage, AsyncStorage
33
from .local import LocalAsyncStorage, LocalSyncStorage
44
from .azure import AzureAsyncStorage, AzureSyncStorage
5+
from .gcp import GcpAsyncStorage,GcpSyncStorage
56

67
# Asynchronous storage classes, looked up by a short backend name.
STORAGE_REGISTRY: Dict[str, Type[AsyncStorage]] = {
    "local": LocalAsyncStorage,
    "azure": AzureAsyncStorage,
    "gcp": GcpAsyncStorage,
}
1012

1113
# Synchronous storage classes, looked up by a short backend name.
SYNC_STORAGE_REGISTRY: Dict[str, Type[SyncStorage]] = {
    "local": LocalSyncStorage,
    "azure": AzureSyncStorage,
    "gcp": GcpSyncStorage,
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
import os
2+
import asyncio
3+
import unittest
4+
import pytest
5+
from unittest.mock import patch, MagicMock, AsyncMock
6+
from unittest import mock
7+
from google.cloud import storage
8+
from google.cloud.storage import Blob
9+
from lib.file_storage.gcp.gcp_storage import GcpAsyncStorage
10+
11+
class TestGCPAsyncStorage(unittest.IsolatedAsyncioTestCase):
    """Tests for GcpAsyncStorage against a mocked ``google.cloud.storage.Client``.

    ``IsolatedAsyncioTestCase`` is used so that the ``async def`` tests are
    actually awaited; on a plain ``TestCase`` the coroutine objects would be
    created and discarded without running any assertion.
    """

    def setUp(self):
        # Function-scope imports: only needed for the scratch directory.
        import shutil
        import tempfile

        # patch.dict restores os.environ afterwards, including any variable
        # the storage constructor itself may set.
        env_patcher = mock.patch.dict(
            os.environ,
            {
                "GCP_STORAGE_PROJECT": "test-project-id",
                "GCP_STORAGE_BUCKET_NAME": "test-bucket",
            },
        )
        env_patcher.start()
        self.addCleanup(env_patcher.stop)

        # Keep the client patched for the whole test, not just for setUp.
        client_patcher = mock.patch("google.cloud.storage.Client")
        mock_storage_client = client_patcher.start()
        self.addCleanup(client_patcher.stop)

        self.storage = GcpAsyncStorage()

        # Redirect downloads to a throw-away directory.
        self.storage.tmp_folder = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, self.storage.tmp_folder, ignore_errors=True)

        self.mock_client = mock_storage_client.return_value
        self.mock_bucket = self.mock_client.bucket.return_value
        self.mock_blob = self.mock_bucket.blob.return_value

    async def test_write_file(self):
        await self.storage.write_file("test.txt", b"content")

        self.mock_client.bucket.assert_called_once_with(self.storage.__bucket_name__)
        self.mock_bucket.blob.assert_called_once_with("test.txt")
        # No mime_type was given and the name is not .mp3, so the generic
        # content type is expected.
        self.mock_blob.upload_from_string.assert_called_once_with(
            b"content", content_type="application/octet-stream"
        )

    async def test_write_file_defaults_mp3_content_type(self):
        await self.storage.write_file("clip.MP3", b"content")

        self.mock_blob.upload_from_string.assert_called_once_with(
            b"content", content_type="audio/mpeg"
        )

    async def test_download_file_to_temp_storage(self):
        file_path = "test.txt"

        # Behave like the real client: write the payload synchronously into
        # whatever file object the storage class hands over.
        self.mock_blob.download_to_file.side_effect = lambda fh: fh.write(b"payload")

        tmp_path = await self.storage._download_file_to_temp_storage(file_path)

        self.assertEqual(tmp_path, os.path.join(self.storage.tmp_folder, file_path))
        self.mock_bucket.blob.assert_called_once_with(file_path)
        self.mock_blob.download_to_file.assert_called_once()

        # The bytes must really have reached the local file.
        with open(tmp_path, "rb") as downloaded:
            self.assertEqual(downloaded.read(), b"payload")

    def test_public_url(self):
        signed_url = "https://storage.example.invalid/signed"
        self.mock_blob.generate_signed_url.return_value = signed_url

        url = self.storage.public_url("test.txt")

        self.mock_blob.generate_signed_url.assert_called_once_with(
            version="v4",
            expiration=mock.ANY,
            method="GET",
        )
        self.assertEqual(url, signed_url)


if __name__ == '__main__':
    unittest.main()

0 commit comments

Comments
 (0)