pythonvectordbceph.py
from flask import Flask, request
#import uuid
import json
import milvus_model
from pymilvus import MilvusClient, DataType, FieldSchema, CollectionSchema
import boto3
import os
import re
from PIL import Image
import timm
from sklearn.preprocessing import normalize
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
import torch
# this import is only needed when the second image embedding function is used
# from transformers import AutoFeatureExtractor, AutoModelForImageClassification


class FeatureExtractor:
    def __init__(self, modelname):
        # Load the pre-trained model
        self.model = timm.create_model(
            modelname, pretrained=True, num_classes=0, global_pool="avg"
        )
        self.model.eval()
        # Get the input size required by the model
        self.input_size = self.model.default_cfg["input_size"]
        config = resolve_data_config({}, model=modelname)
        # Get the preprocessing function provided by TIMM for the model
        self.preprocess = create_transform(**config)

    def __call__(self, imagepath):
        # Preprocess the input image
        input_image = Image.open(imagepath).convert("RGB")  # Convert to RGB if needed
        input_image = self.preprocess(input_image)
        # Convert the image to a PyTorch tensor and add a batch dimension
        input_tensor = input_image.unsqueeze(0)
        # Perform inference
        with torch.no_grad():
            output = self.model(input_tensor)
        # Extract the feature vector
        feature_vector = output.squeeze().numpy()
        return normalize(feature_vector.reshape(1, -1), norm="l2").flatten()
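
# Illustrative usage of FeatureExtractor (not part of the request flow; the file
# name below is made up for the example): it accepts anything PIL.Image.open()
# can read (a file path or a file-like object) and returns an L2-normalized
# numpy vector, 512 floats for "resnet34".
#   extractor = FeatureExtractor("resnet34")
#   vec = extractor("example.jpg")
#   print(vec.shape)  # (512,)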

client = MilvusClient(
    uri=os.getenv("MILVUS_ENDPOINT")
)
# assume a plain HTTP endpoint for the bucket host
endpoint_url = "http://" + os.getenv("BUCKET_HOST") + ":" + os.getenv("BUCKET_PORT")
s3 = boto3.client(
    "s3",
    aws_access_key_id=os.getenv("AWS_ACCESS_KEY_ID"),
    aws_secret_access_key=os.getenv("AWS_SECRET_ACCESS_KEY"),
    endpoint_url=endpoint_url,
    use_ssl=False,
    verify=False,
)
# Store the bucket name and the expected object type
bucket_name = os.getenv("BUCKET_NAME")
object_type = os.getenv("OBJECT_TYPE")

app = Flask(__name__)
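
# The POST handler below expects an S3-style bucket-notification payload; only
# the fields it actually reads are sketched here (illustrative values, field
# names taken from the parsing code below):
# {
#   "Records": [
#     {
#       "eventName": "ObjectCreated:Put",
#       "s3": {"object": {"key": "example.jpg", "tags": {"label": "animal"}}}
#     }
#   ]
# }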

@app.route('/', methods=['POST'])
def pythonvectordbappceph():
    # collection names only support '_', so change '-' in the bucket name to '_'
    collection_name = re.sub('-', '_', bucket_name)
    app.logger.debug("collection name from the bucket: " + collection_name)
    # parse the object name from the event
    # TODO: parse other metadata and add it to the collection
    event_data = json.loads(request.data)
    object_key = event_data['Records'][0]['s3']['object']['key']
    event_type = event_data['Records'][0]['eventName']
    app.logger.debug(object_key)
    tags = event_data['Records'][0]['s3']['object']['tags']
    app.logger.debug("tags : " + str(tags))
    # Create the collection, which holds the object URL (primary key), the embedded vector, and the tags
    if not client.has_collection(collection_name=collection_name):
        fields = [
            FieldSchema(name='url', dtype=DataType.VARCHAR, max_length=2048, is_primary=True),  # VARCHAR fields need a maximum length; 2048 covers typical object URLs
            FieldSchema(name='embedded_vector', dtype=DataType.FLOAT_VECTOR, dim=int(os.getenv("VECTOR_DIMENSION"))),
            FieldSchema(name='tags', dtype=DataType.JSON, nullable=True)
        ]
        schema = CollectionSchema(fields=fields, enable_dynamic_field=True)
        client.create_collection(collection_name=collection_name, schema=schema)
        index_params = client.prepare_index_params()
        index_params.add_index(field_name="embedded_vector", metric_type="L2", index_type="IVF_FLAT", params={"nlist": 16384})
        client.create_index(collection_name=collection_name, index_params=index_params)
        app.logger.debug("collection " + collection_name + " created")
    object_url = endpoint_url + "/" + bucket_name + "/" + object_key
    client.load_collection(collection_name=collection_name)
    # deletions are handled here; the remaining event types fall through to the match block below
    if event_type == "ObjectRemoved:Delete":
        exp = "url == \"" + object_url + "\""
        app.logger.debug("starting deletion of " + object_url)
        res = client.delete(collection_name=collection_name, filter=exp)
        app.logger.debug(res)
        return "delete success"
    object_data = s3.get_object(Bucket=bucket_name, Key=object_key)
    match object_type:
        case "TEXT":
            object_content = object_data["Body"].read().decode("utf-8")
            objectlist = []
            objectlist.append(object_content)
            # default embedding function provided by Milvus; it has some size limitation for the object
            # embedding_fn = milvus_model.DefaultEmbeddingFunction()  # dimension 768
            embedding_fn = milvus_model.dense.SentenceTransformerEmbeddingFunction(model_name='all-MiniLM-L6-v2', device='cpu')  # dimension 384
            vectors = embedding_fn.encode_documents(objectlist)
            vector = vectors[0]
        case "IMAGE":
            object_stream = object_data['Body']
            # dimension 512
            extractor = FeatureExtractor("resnet34")
            vector = extractor(object_stream)
        # Another embedding function for image objects
        # case "IMAGE2":
        #     extractor = AutoFeatureExtractor.from_pretrained("microsoft/resnet-50")
        #     model = AutoModelForImageClassification.from_pretrained("microsoft/resnet-50")
        #     object_stream = object_data['Body']
        #     object_content = Image.open(object_stream)
        #     inputs = extractor(images=object_content, return_tensors="pt")
        #     # dimension 2048
        #     outputs = model(**inputs)
        #     # issue: RPC error [insert_rows], DataNotMatchException: the input data type is inconsistent with the defined schema, the {vector} field should be a float_vector, but got a float
        #     vector = outputs.squeeze().tolist()
        case _:
            # return early so the upsert below never sees an undefined vector
            app.logger.error("Unknown object format")
            return "unknown object format"
    app.logger.debug(vector)
    if len(tags) > 0:
        data = [{"embedded_vector": vector, "url": object_url, "tags": tags}]
    else:
        data = [{"embedded_vector": vector, "url": object_url}]
    res = client.upsert(collection_name=collection_name, data=data)
    app.logger.debug(res)
    return ''


if __name__ == '__main__':
    app.run(debug=True, host='0.0.0.0', port=8080)
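
# Illustrative local test (assumed, not part of the original source): with the
# app running on port 8080 and the environment variables above set, a
# notification mirroring the payload sketch near the route handler can be
# posted with curl:
#   curl -X POST http://localhost:8080/ \
#        -H 'Content-Type: application/json' \
#        -d '{"Records":[{"eventName":"ObjectCreated:Put","s3":{"object":{"key":"example.jpg","tags":{}}}}]}'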