-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
chore: update models of local backend
- Loading branch information
Showing
8 changed files
with
196 additions
and
152 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,2 @@ | ||
from . import efficientnet | ||
from . import resnet | ||
from . import x3d | ||
from . import image | ||
from . import video |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
import tensorflow as tf | ||
import cv2 | ||
import numpy as np | ||
from transformers import TFAutoModel, AutoConfig | ||
from tensorflow.keras.applications import ( | ||
EfficientNetV2B0, | ||
ResNet50, | ||
preprocess_input as keras_preprocess_input, | ||
) | ||
|
||
|
||
# EmbeddingExtractor Class with model name as a parameter | ||
class EmbeddingExtractor: | ||
def __init__(self, model_name="EfficientNetV2B0"): | ||
self.model_name = model_name | ||
self.model, self.preprocess_fn = self.load_model() | ||
|
||
def load_model(self): | ||
if self.model_name == "EfficientNetV2B0": | ||
base_model = EfficientNetV2B0(include_top=False, pooling="avg") | ||
preprocess_fn = ( | ||
keras_preprocess_input # Define custom preprocessing if needed | ||
) | ||
elif self.model_name == "ResNet50": | ||
base_model = ResNet50(include_top=False, pooling="avg") | ||
preprocess_fn = ( | ||
keras_preprocess_input # Define custom preprocessing if needed | ||
) | ||
else: | ||
config = AutoConfig.from_pretrained(self.model_name) | ||
base_model = TFAutoModel.from_pretrained(self.model_name, config=config) | ||
preprocess_fn = ( | ||
keras_preprocess_input # Define custom preprocessing if needed | ||
) | ||
|
||
return ( | ||
tf.keras.Model(inputs=base_model.input, outputs=base_model.output), | ||
preprocess_fn, | ||
) | ||
|
||
def preprocess_image(self, image): | ||
image = cv2.resize(image, (224, 224)) | ||
image = image.astype("float32") | ||
image = self.preprocess_fn(image) | ||
return image | ||
|
||
def extract_image_embedding(self, image_path): | ||
image = cv2.imread(image_path) | ||
if image is not None: | ||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) | ||
image = self.preprocess_image(image) | ||
embedding = self.model.predict(np.expand_dims(image, axis=0)) | ||
return embedding.squeeze() | ||
return None | ||
|
||
def process_image(self, image_id, file_path): | ||
embedding = self.extract_image_embedding(file_path) | ||
if embedding is not None: | ||
print(f"Extracted embedding for image {image_id}: {embedding.shape}") | ||
else: | ||
print(f"Failed to extract embedding for image {image_id}") | ||
return embedding | ||
|
||
|
||
if __name__ == "__main__": | ||
image_path = "path_to_your_image.jpg" | ||
|
||
# Pass the model name as a parameter | ||
model_name = "microsoft/resnet-50" # Example for Hugging Face model | ||
extractor = EmbeddingExtractor(model_name=model_name) | ||
|
||
# Process the image | ||
extractor.process_image(1, image_path) |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
import torch | ||
import torch.nn as nn | ||
import torchvision.transforms as transforms | ||
import torchvision.models.video as models | ||
|
||
# import tensorflow as tf | ||
import cv2 | ||
import numpy as np | ||
|
||
# from tensorflow.keras.applications import EfficientNetV2B0 # Replaceable with other models | ||
from transformers import TFAutoModel # For Hugging Face models | ||
|
||
# from tensorflow.keras.models import load_model # To load the converted X3D model | ||
|
||
|
||
class EmbeddingExtractor: | ||
def __init__( | ||
self, | ||
model_name="EfficientNetV2B0", | ||
input_shape=(224, 224, 3), | ||
device="cuda" if torch.cuda.is_available() else "cpu", | ||
): | ||
self.model_name = model_name | ||
self.input_shape = input_shape | ||
self.device = device | ||
self.model = self.load_model() | ||
|
||
def load_model(self): | ||
## if self.model_name == "EfficientNetV2B0": | ||
## base_model = EfficientNetV2B0(include_top=False, pooling="avg") | ||
## model = tf.keras.Model(inputs=base_model.input, outputs=base_model.output) | ||
## else if self.model_name == "x3d_model_tf": | ||
## # Load the converted X3D model | ||
## model = load_model(self.model_name) | ||
## else: | ||
## # Load a model from Hugging Face if it's a supported video model | ||
## model = TFAutoModel.from_pretrained(self.model_name) | ||
# Load the pre-trained X3D model from torchvision | ||
if self.model_name == "x3d_m": | ||
model = models.video.x3d_x3d_m(pretrained=True) | ||
elif self.model_name == "x3d_s": | ||
model = models.video.x3d_x3d_s(pretrained=True) | ||
elif self.model_name == "x3d_l": | ||
model = models.video.x3d_x3d_l(pretrained=True) | ||
else: | ||
raise ValueError(f"Unsupported model: {self.model_name}") | ||
|
||
# Remove the final classification layer to extract embeddings | ||
model = nn.Sequential(*list(model.children())[:-1]) | ||
model.to(self.device) | ||
model.eval() | ||
|
||
return model | ||
|
||
def preprocess_video_frames(self, frames): | ||
## Resize each frame to the input shape for the specific model | ||
# frames = [cv2.resize(frame, (self.input_shape[0], self.input_shape[1])) for frame in frames] | ||
# frames = np.array(frames).astype("float32") / 255.0 # Normalize to [0, 1] | ||
# return frames | ||
# Resize and normalize each frame to the input shape for X3D (3D CNN models expect normalization) | ||
transform = transforms.Compose( | ||
[ | ||
transforms.ToPILImage(), | ||
transforms.Resize(self.input_shape), | ||
transforms.ToTensor(), | ||
transforms.Normalize( | ||
mean=[0.45, 0.45, 0.45], std=[0.225, 0.225, 0.225] | ||
), | ||
] | ||
) | ||
processed_frames = [transform(frame) for frame in frames] | ||
return torch.stack( | ||
processed_frames | ||
) # Stack frames into a tensor (batch of frames) | ||
|
||
def extract_video_embeddings(self, video_path): | ||
cap = cv2.VideoCapture(video_path) | ||
frames = [] | ||
success = True | ||
frame_count = 0 | ||
|
||
# Extract up to 90 frames | ||
while success and frame_count < 90: | ||
success, frame = cap.read() | ||
if success: | ||
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) | ||
frames.append(frame) | ||
frame_count += 1 | ||
|
||
cap.release() | ||
|
||
if len(frames) > 0: | ||
## frames = self.preprocess_video_frames(frames) | ||
## embeddings = self.model.predict(np.expand_dims(frames, axis=0)) # Add batch dimension | ||
## return embeddings.squeeze() # Return embedding as numpy array | ||
frames_tensor = ( | ||
self.preprocess_video_frames(frames).unsqueeze(0).to(self.device) | ||
) # Add batch dimension | ||
with torch.no_grad(): | ||
embeddings = self.model(frames_tensor) | ||
return embeddings.squeeze().cpu().numpy() # Convert to numpy array | ||
return None | ||
|
||
def process_video(self, video_id, file_path): | ||
embeddings = self.extract_video_embeddings(file_path) | ||
if embeddings is not None: | ||
print(f"Extracted embeddings for video {video_id}: {embeddings.shape}") | ||
else: | ||
print(f"Failed to extract embeddings for video {video_id}") | ||
return embeddings | ||
|
||
|
||
if __name__ == "__main__": | ||
video_path = "path_to_your_video.mp4" | ||
## extractor = EmbeddingExtractor(model_name="EfficientNetV2B0") # Change model name here | ||
extractor = EmbeddingExtractor( | ||
model_name="x3d_m" | ||
) # You can change to x3d_s or x3d_l | ||
extractor.process_video(1, video_path) |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -33,3 +33,4 @@ opencv-python==4.10.0.84 | |
pymilvus==2.4.8 | ||
tensorrt==10.5.0 | ||
nvidia-tensorrt==99.0.0 | ||
huggingface-hub==0.25.2 |