diff --git a/examples/vision/nl_image_search.py b/examples/vision/nl_image_search.py
index 942d808656..624030c462 100644
--- a/examples/vision/nl_image_search.py
+++ b/examples/vision/nl_image_search.py
@@ -42,7 +42,6 @@
 from tensorflow.keras import layers
 import tensorflow_hub as hub
 import tensorflow_text as text
-import tensorflow_addons as tfa
 import matplotlib.pyplot as plt
 import matplotlib.image as mpimg
 from tqdm import tqdm
@@ -69,41 +68,41 @@
 """
 
 root_dir = "datasets"
-annotations_dir = os.path.join(root_dir, "annotations")
-images_dir = os.path.join(root_dir, "train2014")
+annotations_dir = os.path.join(root_dir, "captions_extracted/annotations")
+images_dir = os.path.join(root_dir, "train2014_extracted/train2014")
 tfrecords_dir = os.path.join(root_dir, "tfrecords")
 annotation_file = os.path.join(annotations_dir, "captions_train2014.json")
 
 # Download caption annotation files
 if not os.path.exists(annotations_dir):
-    annotation_zip = tf.keras.utils.get_file(
+    annotation_zip = keras.utils.get_file(
         "captions.zip",
         cache_dir=os.path.abspath("."),
         origin="http://images.cocodataset.org/annotations/annotations_trainval2014.zip",
         extract=True,
     )
-    os.remove(annotation_zip)
+    os.remove(os.path.join(root_dir, "captions.zip"))
 
 # Download image files
 if not os.path.exists(images_dir):
-    image_zip = tf.keras.utils.get_file(
+    image_zip = keras.utils.get_file(
         "train2014.zip",
         cache_dir=os.path.abspath("."),
         origin="http://images.cocodataset.org/zips/train2014.zip",
         extract=True,
     )
-    os.remove(image_zip)
+    os.remove(os.path.join(root_dir, "train2014.zip"))
 
 print("Dataset is downloaded and extracted successfully.")
 
 with open(annotation_file, "r") as f:
     annotations = json.load(f)["annotations"]
 
 image_path_to_caption = collections.defaultdict(list)
 for element in annotations:
     caption = f"{element['caption'].lower().rstrip('.')}"
     image_path = images_dir + "/COCO_train2014_" + "%012d.jpg" % (element["image_id"])
     image_path_to_caption[image_path].append(caption)
 
 image_paths = list(image_path_to_caption.keys())
 print(f"Number of images: {len(image_paths)}")
@@ -228,7 +227,7 @@ def project_embeddings(
 ):
     projected_embeddings = layers.Dense(units=projection_dims)(embeddings)
     for _ in range(num_projection_layers):
-        x = tf.nn.gelu(projected_embeddings)
+        x = keras.activations.gelu(projected_embeddings)
         x = layers.Dense(projection_dims)(x)
         x = layers.Dropout(dropout_rate)(x)
         x = layers.Add()([projected_embeddings, x])
@@ -258,7 +257,7 @@ def create_vision_encoder(
     # Receive the images as inputs.
     inputs = layers.Input(shape=(299, 299, 3), name="image_input")
     # Preprocess the input image.
-    xception_input = tf.keras.applications.xception.preprocess_input(inputs)
+    xception_input = keras.applications.xception.preprocess_input(inputs)
     # Generate the embeddings for the images using the xception model.
     embeddings = xception(xception_input)
     # Project the embeddings produced by the model.
@@ -285,13 +284,12 @@ def create_text_encoder(
         "https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/2",
         name="text_preprocessing",
     )
-    # Load the pre-trained BERT model to be used as the base encoder.
-    bert = hub.KerasLayer(
-        "https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-4_H-512_A-8/1",
-        "bert",
-    )
-    # Set the trainability of the base encoder.
-    bert.trainable = trainable
+    # Load the pre-trained BERT model to be used as the base encoder, with trainable set to False.
+    bert = hub.KerasLayer(
+        "https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-4_H-512_A-8/1",
+        trainable=False,
+        name="bert",
+    )
     # Receive the text as inputs.
     inputs = layers.Input(shape=(), dtype=tf.string, name="text_input")
     # Preprocess the text.
@@ -407,7 +405,7 @@ def test_step(self, features):
 )
 dual_encoder = DualEncoder(text_encoder, vision_encoder, temperature=0.05)
 dual_encoder.compile(
-    optimizer=tfa.optimizers.AdamW(learning_rate=0.001, weight_decay=0.001)
+    optimizer=keras.optimizers.AdamW(learning_rate=0.001, weight_decay=0.001)
 )
 
 """
@@ -419,7 +417,8 @@ def test_step(self, features):
 print(f"Number of GPUs: {len(tf.config.list_physical_devices('GPU'))}")
 print(f"Number of examples (caption-image pairs): {train_example_count}")
 print(f"Batch size: {batch_size}")
-print(f"Steps per epoch: {int(np.ceil(train_example_count / batch_size))}")
+steps_per_epoch = int(np.ceil(train_example_count / batch_size))
+print(f"Steps per epoch: {steps_per_epoch}")
 train_dataset = get_dataset(os.path.join(tfrecords_dir, "train-*.tfrecord"), batch_size)
 valid_dataset = get_dataset(os.path.join(tfrecords_dir, "valid-*.tfrecord"), batch_size)
 # Create a learning rate scheduler callback.
@@ -427,12 +426,13 @@
     monitor="val_loss", factor=0.2, patience=3
 )
 # Create an early stopping callback.
-early_stopping = tf.keras.callbacks.EarlyStopping(
+early_stopping = keras.callbacks.EarlyStopping(
     monitor="val_loss", patience=5, restore_best_weights=True
 )
 history = dual_encoder.fit(
     train_dataset,
     epochs=num_epochs,
+    steps_per_epoch=steps_per_epoch,
     validation_data=valid_dataset,
     callbacks=[reduce_lr, early_stopping],
 )
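
Note on the path changes in the second hunk: newer Keras releases of keras.utils.get_file with extract=True keep the downloaded archive and unpack it into a sibling "<archive name>_extracted/" directory, which is why the patch points annotations_dir and images_dir at the *_extracted locations and removes each zip by its known path rather than by the function's return value. A minimal sketch of that flow, assuming this extraction behavior (the final directory listing is only for illustration):

import os
import keras

root_dir = "datasets"
# Downloads to ./datasets/captions.zip and, under the assumed Keras behavior,
# unpacks it into ./datasets/captions_extracted/.
extracted = keras.utils.get_file(
    "captions.zip",
    cache_dir=os.path.abspath("."),
    origin="http://images.cocodataset.org/annotations/annotations_trainval2014.zip",
    extract=True,
)
print(extracted)
# The archive is kept alongside the extracted folder, so remove it explicitly.
os.remove(os.path.join(root_dir, "captions.zip"))
print(os.listdir(os.path.join(root_dir, "captions_extracted", "annotations")))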
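
Likewise, the optimizer hunk works because AdamW with decoupled weight decay now ships in Keras itself (keras.optimizers.AdamW, available since TF 2.11), so tensorflow_addons can be dropped. Below is a self-contained sketch of the resulting compile/fit pattern; the toy model and repeating toy dataset are placeholders standing in for the dual encoder and the TFRecord pipeline, not part of the example:

import numpy as np
import tensorflow as tf
import keras

# Toy stand-ins for the dual encoder and the TFRecord datasets.
model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(1)])
model.compile(
    optimizer=keras.optimizers.AdamW(learning_rate=0.001, weight_decay=0.001),
    loss="mse",
)
example_count, batch_size = 100, 16
# Same rounding as the example: a partial final batch counts as a step.
steps_per_epoch = int(np.ceil(example_count / batch_size))
dataset = (
    tf.data.Dataset.from_tensor_slices(
        (np.random.rand(example_count, 4), np.random.rand(example_count, 1))
    )
    .batch(batch_size)
    .repeat()  # endless stream, so steps_per_epoch bounds each epoch
)
model.fit(dataset, epochs=2, steps_per_epoch=steps_per_epoch)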