diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..8683cb4 --- /dev/null +++ b/.gitignore @@ -0,0 +1,8 @@ +*/.vscode/* +media/* +*.pyc +env/ +db.sqlite3 +uwsgi.ini +vilbert_multitask_nginx.conf +static/ \ No newline at end of file diff --git a/demo/__init__.py b/demo/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/demo/admin.py b/demo/admin.py new file mode 100644 index 0000000..8d50190 --- /dev/null +++ b/demo/admin.py @@ -0,0 +1,34 @@ +from django.contrib import admin + +from .models import Tasks, QuestionAnswer +# from import_export.admin import ImportExportMixin + + +class ImportExportTimeStampedAdmin(admin.ModelAdmin): + exclude = ("created_at", "modified_at") + + +@admin.register(Tasks) +class TaskAdmin(ImportExportTimeStampedAdmin): + readonly_fields = ("created_at",) + list_display = ( + "unique_id", + "name", + "placeholder", + "example", + "num_of_images", + "description", + ) + + +@admin.register(QuestionAnswer) +class QuestionAnswerAdmin(ImportExportTimeStampedAdmin): + readonly_fields = ("created_at",) + list_display = ( + "task", + "input_text", + "input_images", + "answer_text", + "answer_images", + "socket_id", + ) diff --git a/demo/apps.py b/demo/apps.py new file mode 100644 index 0000000..57920c3 --- /dev/null +++ b/demo/apps.py @@ -0,0 +1,5 @@ +from django.apps import AppConfig + + +class DemoConfig(AppConfig): + name = 'demo' diff --git a/demo/constants.py b/demo/constants.py new file mode 100644 index 0000000..c30d13f --- /dev/null +++ b/demo/constants.py @@ -0,0 +1,22 @@ +from django.conf import settings +import os + + +COCO_IMAGES_PATH = os.path.join(settings.MEDIA_ROOT, "test2014") +COCO_IMAGES_URL = os.path.join(settings.MEDIA_URL, "test2014") + +VILBERT_MULTITASK_CONFIG = { + "gpuid": 1, + "image_dir": os.path.join(settings.MEDIA_ROOT, "demo"), +} + + +SLACK_WEBHOOK_URL = "" +BASE_VQA_DIR_PATH = "" +COCO_PARTIAL_IMAGE_NAME = "COCO_test2014_" +RABBITMQ_QUEUE_USERNAME = "" +RABBITMQ_QUEUE_PASSWORD = "" +RABBITMQ_HOST_SERVER = "" +RABBITMQ_HOST_PORT = "" +RABBITMQ_VIRTUAL_HOST = "" +IMAGES_BASE_URL = "" \ No newline at end of file diff --git a/demo/consumers.py b/demo/consumers.py new file mode 100644 index 0000000..f8434e7 --- /dev/null +++ b/demo/consumers.py @@ -0,0 +1,12 @@ +from channels import Group +from .utils import log_to_terminal + +def ws_connect(message): + print("User connnected via Socket") + + +def ws_message(message): + print("Message recieved from client side and the content is ", message.content['text']) + socketid = message.content['text'] + Group(socketid).add(message.reply_channel) + log_to_terminal(socketid, {"info": "User added to the Channel Group"}) \ No newline at end of file diff --git a/demo/migrations/0001_add_models_for_demo.py b/demo/migrations/0001_add_models_for_demo.py new file mode 100644 index 0000000..9257c35 --- /dev/null +++ b/demo/migrations/0001_add_models_for_demo.py @@ -0,0 +1,54 @@ +# -*- coding: utf-8 -*- +# Generated by Django 1.11.23 on 2020-03-28 01:37 +from __future__ import unicode_literals + +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + initial = True + + dependencies = [ + ] + + operations = [ + migrations.CreateModel( + name='QuestionAnswer', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('created_at', models.DateTimeField(auto_now_add=True)), + ('modified_at', models.DateTimeField(auto_now=True)), + ('input_text', models.TextField(blank=True, null=True)), + ('input_images', models.CharField(blank=True, max_length=10000, null=True)), + ('answer_text', models.TextField(blank=True, null=True)), + ('answer_images', models.CharField(blank=True, max_length=10000, null=True)), + ('socket_id', models.CharField(blank=True, max_length=1000, null=True)), + ], + options={ + 'db_table': 'questionanswer', + }, + ), + migrations.CreateModel( + name='Tasks', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('created_at', models.DateTimeField(auto_now_add=True)), + ('modified_at', models.DateTimeField(auto_now=True)), + ('unique_id', models.PositiveIntegerField(unique=True)), + ('name', models.CharField(blank=True, max_length=1000, null=True)), + ('placeholder', models.TextField(blank=True, null=True)), + ('description', models.TextField(blank=True, null=True)), + ('num_of_images', models.PositiveIntegerField()), + ], + options={ + 'db_table': 'tasks', + }, + ), + migrations.AddField( + model_name='questionanswer', + name='task', + field=models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='demo.Tasks'), + ), + ] diff --git a/demo/migrations/0002_attachment.py b/demo/migrations/0002_attachment.py new file mode 100644 index 0000000..93a77ec --- /dev/null +++ b/demo/migrations/0002_attachment.py @@ -0,0 +1,22 @@ +# -*- coding: utf-8 -*- +# Generated by Django 1.10.1 on 2020-03-30 05:18 +from __future__ import unicode_literals + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('demo', '0001_add_models_for_demo'), + ] + + operations = [ + migrations.CreateModel( + name='Attachment', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('file', models.FileField(upload_to='attachments')), + ], + ), + ] diff --git a/demo/migrations/0003_tasks_example.py b/demo/migrations/0003_tasks_example.py new file mode 100644 index 0000000..bbe50e2 --- /dev/null +++ b/demo/migrations/0003_tasks_example.py @@ -0,0 +1,20 @@ +# -*- coding: utf-8 -*- +# Generated by Django 1.10.1 on 2020-04-05 10:07 +from __future__ import unicode_literals + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('demo', '0002_attachment'), + ] + + operations = [ + migrations.AddField( + model_name='tasks', + name='example', + field=models.CharField(blank=True, max_length=1000, null=True), + ), + ] diff --git a/demo/migrations/__init__.py b/demo/migrations/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/demo/models.py b/demo/models.py new file mode 100644 index 0000000..bbc2ace --- /dev/null +++ b/demo/models.py @@ -0,0 +1,46 @@ +from django.db import models +from django.utils.html import format_html + +class TimeStampedModel(models.Model): + """ + An abstract base class model that provides self-managed `created_at` and + `modified_at` fields. + """ + + created_at = models.DateTimeField(auto_now_add=True) + modified_at = models.DateTimeField(auto_now=True) + + class Meta: + abstract = True + app_label = "demo" + + +class Tasks(TimeStampedModel): + unique_id = models.PositiveIntegerField(unique=True) + name = models.CharField(max_length=1000, blank=True, null=True) + placeholder = models.TextField(null=True, blank=True) + description = models.TextField(null=True, blank=True) + num_of_images = models.PositiveIntegerField() + example = models.CharField(max_length=1000, null=True, blank=True) + + class Meta: + app_label = "demo" + db_table = "tasks" + + +class QuestionAnswer(TimeStampedModel): + task = models.ForeignKey(Tasks) + input_text = models.TextField(null=True, blank=True) + input_images = models.CharField(max_length=10000, null=True, blank=True) + answer_text = models.TextField(null=True, blank=True) + answer_images = models.CharField(max_length=10000, null=True, blank=True) + socket_id = models.CharField(max_length=1000, null=True, blank=True) + class Meta: + app_label = "demo" + db_table = "questionanswer" + + def img_url(self): + return format_html("", self.image) + +class Attachment(models.Model): + file = models.FileField(upload_to='attachments') \ No newline at end of file diff --git a/demo/routers.py b/demo/routers.py new file mode 100644 index 0000000..dcee4e9 --- /dev/null +++ b/demo/routers.py @@ -0,0 +1,7 @@ +from channels.routing import route +from .consumers import ws_message, ws_connect + +channel_routing = [ + route("websocket.receive", ws_message), + route("websocket.connect", ws_connect), +] diff --git a/demo/sender.py b/demo/sender.py new file mode 100644 index 0000000..372937c --- /dev/null +++ b/demo/sender.py @@ -0,0 +1,35 @@ +from django.conf import settings +from .utils import log_to_terminal + +import os +import pika +import sys +import json + + +def vilbert_task(image_path, question, task_id, socket_id): + + connection = pika.BlockingConnection(pika.ConnectionParameters( + host='localhost', + port=5672, + socket_timeout=10000)) + channel = connection.channel() + queue = "vilbert_multitask_queue" + channel.queue_declare(queue=queue, durable=True) + message = { + 'image_path': image_path, + 'question': question, + 'socket_id': socket_id, + "task_id": task_id + } + log_to_terminal(socket_id, {"terminal": "Publishing job to ViLBERT Queue"}) + channel.basic_publish(exchange='', + routing_key=queue, + body=json.dumps(message), + properties=pika.BasicProperties( + delivery_mode = 2, # make message persistent + )) + + print(" [x] Sent %r" % message) + log_to_terminal(socket_id, {"terminal": "Job published successfully"}) + connection.close() \ No newline at end of file diff --git a/demo/templates/index.html b/demo/templates/index.html new file mode 100644 index 0000000..e7cbebe --- /dev/null +++ b/demo/templates/index.html @@ -0,0 +1,25 @@ + +{% load static %} + + + + {% block head %} {%include "head.html"%} {%endblock%} + + + + {%block header%} {%include "header.html"%} {% endblock %} + +
+
+
+ {% block header_content%} {%include "header_content.html"%} {%endblock%} +
+
+
+ {%block demo_images %} {%include "demo_images.html"%} {%endblock%} +
+ {%block terminal %} + {% include "terminal.html" %}{% endblock %} {% block result%}{%include "result.html"%} + {%endblock%} {% block credits %} {%include "credits.html"%} {%endblock%} + + \ No newline at end of file diff --git a/demo/templates/vilbert_multitask/credits.html b/demo/templates/vilbert_multitask/credits.html new file mode 100644 index 0000000..f616168 --- /dev/null +++ b/demo/templates/vilbert_multitask/credits.html @@ -0,0 +1,85 @@ + + + + + + +
+ +
+ Built by @rishabh jain +
+
+
+ +
+ We thank @jiasen lu for his help. +
+
+
+
+ \ No newline at end of file diff --git a/demo/templates/vilbert_multitask/demo_images.html b/demo/templates/vilbert_multitask/demo_images.html new file mode 100644 index 0000000..cff5435 --- /dev/null +++ b/demo/templates/vilbert_multitask/demo_images.html @@ -0,0 +1,147 @@ +{%load static%} + +
+ +
+
+
+ OR +
+
+
+
+
+

Upload your own images

+
+
+
+
+ + +
+
+
+
+
+
+
+
+
+
+
+ +
+
+ +
+
+
\ No newline at end of file diff --git a/demo/templates/vilbert_multitask/form.html b/demo/templates/vilbert_multitask/form.html new file mode 100644 index 0000000..1047a80 --- /dev/null +++ b/demo/templates/vilbert_multitask/form.html @@ -0,0 +1,3 @@ +
+

Drop files here or click to upload

+
\ No newline at end of file diff --git a/demo/templates/vilbert_multitask/head.html b/demo/templates/vilbert_multitask/head.html new file mode 100644 index 0000000..de5c0e0 --- /dev/null +++ b/demo/templates/vilbert_multitask/head.html @@ -0,0 +1,49 @@ +{% load static %} + + + + + + + + + + + + + + + + + + + + + +CloudCV: ViLBERT Multi-Task Demo + + + + + + + + + + + + + + + + + + + diff --git a/demo/templates/vilbert_multitask/header.html b/demo/templates/vilbert_multitask/header.html new file mode 100644 index 0000000..e34b824 --- /dev/null +++ b/demo/templates/vilbert_multitask/header.html @@ -0,0 +1,522 @@ + + + +Fork me on GitHub diff --git a/demo/templates/vilbert_multitask/header_content.html b/demo/templates/vilbert_multitask/header_content.html new file mode 100644 index 0000000..1f1221f --- /dev/null +++ b/demo/templates/vilbert_multitask/header_content.html @@ -0,0 +1,19 @@ + + + \ No newline at end of file diff --git a/demo/templates/vilbert_multitask/result.html b/demo/templates/vilbert_multitask/result.html new file mode 100644 index 0000000..4886271 --- /dev/null +++ b/demo/templates/vilbert_multitask/result.html @@ -0,0 +1,527 @@ + + + + +
+ +
+ +
+ diff --git a/demo/templates/vilbert_multitask/terminal.html b/demo/templates/vilbert_multitask/terminal.html new file mode 100644 index 0000000..e1294f3 --- /dev/null +++ b/demo/templates/vilbert_multitask/terminal.html @@ -0,0 +1,23 @@ +
+

Terminal

+ +
+
    +
+
+
+ +
+

+

How it works

+

+ +
    +
  1. You upload an image.
  2. +
  3. Our servers run the deep-learning based algorithm. +
  4. +
  5. Results and updates are shown in real-time.
  6. +
+


+
+
\ No newline at end of file diff --git a/demo/tests.py b/demo/tests.py new file mode 100644 index 0000000..7ce503c --- /dev/null +++ b/demo/tests.py @@ -0,0 +1,3 @@ +from django.test import TestCase + +# Create your tests here. diff --git a/demo/urls.py b/demo/urls.py new file mode 100644 index 0000000..76be6a7 --- /dev/null +++ b/demo/urls.py @@ -0,0 +1,11 @@ + +from django.conf.urls import url + +from .views import vilbert_multitask, get_task_details, file_upload + +app_name = "demo" +urlpatterns = [ + url(r'^upload_image/', file_upload, name='upload_image'), + url(r"^get_task_details/(?P[0-9]+)/$", get_task_details, name="get_task_details"), + url(r"^$", vilbert_multitask, name="vilbert_multitask"), +] diff --git a/demo/utils.py b/demo/utils.py new file mode 100644 index 0000000..120fe6b --- /dev/null +++ b/demo/utils.py @@ -0,0 +1,6 @@ +from channels import Group + +import json + +def log_to_terminal(socketid, message): + Group(socketid).send({"text": json.dumps(message)}) \ No newline at end of file diff --git a/demo/views.py b/demo/views.py new file mode 100644 index 0000000..c942473 --- /dev/null +++ b/demo/views.py @@ -0,0 +1,123 @@ +# from rest_framework.decorators import api_view +from django.http import JsonResponse +# from rest_framework.response import Response +# from rest_framework import status +from channels import Group +from django.shortcuts import render +from django.views.decorators.csrf import csrf_exempt +from django.conf import settings + +from .sender import vilbert_task +from .utils import log_to_terminal +from .models import Tasks, QuestionAnswer + +import uuid +import os +import random +import traceback +# import urllib2 +import demo.constants as constants + +COCO_PARTIAL_IMAGE_NAME = constants.COCO_PARTIAL_IMAGE_NAME + +@csrf_exempt +def vilbert_multitask(request, template_name="index.html"): + socketid = uuid.uuid4() + if request.method == "POST": + try: + # Fetch the parameters from client side + socketid = request.POST.get("socket_id") + task_id = request.POST.get("task_id") + input_question = request.POST.get("question").lower() + input_images_list = request.POST.getlist("image_list[]") + print(input_images_list, input_question, task_id) + abs_image_path = [] + # if len(input_images_list) == 1: + # abs_image_path.append(str(os.path.join(settings.BASE_DIR, str(input_images_list[0][1:])))) + # print("Absoulte image path", abs_image_path) + # elif len(input_images_list) == 2: + # for i in range(len(input_images_list)): + # abs_image_path.append(str(os.path.join(settings.BASE_DIR, str(input_images_list[i][1:])))) + # else: + for i in range(len(input_images_list)): + abs_image_path.append(str(os.path.join(settings.BASE_DIR, str(input_images_list[i][1:])))) + # out_dir = os.path.dirname(abs_image_path) + print(socketid, task_id, input_question, abs_image_path) + # Run the Model wrapper + log_to_terminal(socketid, {"terminal": "Starting Vilbert Multitask Job..."}) + vilbert_task(abs_image_path, str(input_question), task_id, socketid) + except Exception as e: + log_to_terminal(socketid, {"terminal": traceback.print_exc()}) + demo_images, images_name = get_demo_images(constants.COCO_IMAGES_PATH) + return render(request, template_name, {"demo_images": demo_images, + "socketid": socketid, + "images_name": images_name}) + + +def get_task_details(request, task_id): + try: + task = Tasks.objects.get(unique_id=task_id) + except Tasks.DoesNotExist: + response_data = { + "error": "Tasks with id {} doesn't exist".format(task_id) + } + return JsonResponse(response_data) + response_data = { + "unique_id": task.unique_id, + "name": task.name, + "placeholder": task.placeholder, + "description": task.description, + "num_of_images": task.num_of_images, + "example": task.example + } + return JsonResponse(response_data) + + +def get_demo_images(demo_images_path): + try: + image_count = 0 + demo_images = [] + while(image_count<6): + random_image = random.choice(os.listdir(demo_images_path)) + if COCO_PARTIAL_IMAGE_NAME in random_image: + demo_images.append(random_image) + image_count += 1 + + demo_images_path = [os.path.join(constants.COCO_IMAGES_URL, x) for x in demo_images] + images_name = [x for x in demo_images] + except Exception as e: + print(traceback.print_exc()) + images = ['img1.jpg', 'img2.jpg', 'img3.jpg', 'img4.jpg', 'img5.jpg', 'img6.jpg',] + demo_images_path = [os.path.join(settings.STATIC_URL, 'images', x) for x in images] + images_name = [x for x in images] + return demo_images_path, images_name + + +def handle_uploaded_file(f, path): + with open(path, 'wb+') as destination: + for chunk in f.chunks(): + destination.write(chunk) + +@csrf_exempt +def file_upload(request): + if request.method == "POST": + images = request.FILES.getlist("files[]") + print("Image", images) + socketid = request.POST.get('socketid') + dir_type = constants.VILBERT_MULTITASK_CONFIG['image_dir'] + + # folder_uuid = uuid.uuid4() + # output_dir = os.path.join(dir_type, str(folder_uuid)) + # if not os.path.exists(output_dir): + # os.makedirs(output_dir) + file_paths = [] + for i in images: + image_uuid = uuid.uuid4() + image_extension = str(i).split(".")[-1] + img_path = os.path.join(dir_type, str(image_uuid)) + "." + image_extension + # handle image upload + handle_uploaded_file(i, img_path) + file_paths.append(img_path.replace(settings.BASE_DIR, "")) + + img_url = img_path.replace(settings.BASE_DIR, "") + return JsonResponse({"file_paths": file_paths}) diff --git a/manage.py b/manage.py new file mode 100755 index 0000000..ffff8d0 --- /dev/null +++ b/manage.py @@ -0,0 +1,22 @@ +#!/usr/bin/env python +import os +import sys + +if __name__ == "__main__": + os.environ.setdefault("DJANGO_SETTINGS_MODULE", "vilbert_multitask.settings") + try: + from django.core.management import execute_from_command_line + except ImportError: + # The above import may fail for some other reason. Ensure that the + # issue is really that Django is missing to avoid masking other + # exceptions on Python 2. + try: + import django + except ImportError: + raise ImportError( + "Couldn't import Django. Are you sure it's installed and " + "available on your PYTHONPATH environment variable? Did you " + "forget to activate a virtual environment?" + ) + raise + execute_from_command_line(sys.argv) diff --git a/uwsgi_params b/uwsgi_params new file mode 100644 index 0000000..0bdf13b --- /dev/null +++ b/uwsgi_params @@ -0,0 +1,15 @@ +uwsgi_param QUERY_STRING $query_string; +uwsgi_param REQUEST_METHOD $request_method; +uwsgi_param CONTENT_TYPE $content_type; +uwsgi_param CONTENT_LENGTH $content_length; + +uwsgi_param REQUEST_URI $request_uri; +uwsgi_param PATH_INFO $document_uri; +uwsgi_param DOCUMENT_ROOT $document_root; +uwsgi_param SERVER_PROTOCOL $server_protocol; +uwsgi_param HTTPS $https if_not_empty; + +uwsgi_param REMOTE_ADDR $remote_addr; +uwsgi_param REMOTE_PORT $remote_port; +uwsgi_param SERVER_PORT $server_port; +uwsgi_param SERVER_NAME $server_name; \ No newline at end of file diff --git a/vilbert_multitask/__init__.py b/vilbert_multitask/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vilbert_multitask/asgi.py b/vilbert_multitask/asgi.py new file mode 100644 index 0000000..bd8dcdd --- /dev/null +++ b/vilbert_multitask/asgi.py @@ -0,0 +1,6 @@ +import os +from channels.asgi import get_channel_layer + +os.environ.setdefault("DJANGO_SETTINGS_MODULE", "vilbert_multitask.settings") + +channel_layer = get_channel_layer() \ No newline at end of file diff --git a/vilbert_multitask/settings.py b/vilbert_multitask/settings.py new file mode 100644 index 0000000..c47b6e3 --- /dev/null +++ b/vilbert_multitask/settings.py @@ -0,0 +1,149 @@ +""" +Django settings for vilbert_multitask project. + +Generated by 'django-admin startproject' using Django 1.11.23. + +For more information on this file, see +https://docs.djangoproject.com/en/1.11/topics/settings/ + +For the full list of settings and their values, see +https://docs.djangoproject.com/en/1.11/ref/settings/ +""" + +import os + +# Build paths inside the project like this: os.path.join(BASE_DIR, ...) +BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + + +# Quick-start development settings - unsuitable for production +# See https://docs.djangoproject.com/en/1.11/howto/deployment/checklist/ + +# SECURITY WARNING: keep the secret key used in production secret! +SECRET_KEY = 'v0)e5((^-3_jpp1ghg-tq@!hr_quadcpojvzdvd2yworqajb)z' + +# SECURITY WARNING: don't run with debug turned on in production! +DEBUG = True + +ALLOWED_HOSTS = [] + + +# Application definition + +INSTALLED_APPS = [ + 'django.contrib.admin', + 'django.contrib.auth', + 'django.contrib.contenttypes', + 'django.contrib.sessions', + 'django.contrib.messages', + 'django.contrib.staticfiles', + "channels", + "demo", +] + +MIDDLEWARE = [ + 'django.middleware.security.SecurityMiddleware', + 'django.contrib.sessions.middleware.SessionMiddleware', + 'django.middleware.common.CommonMiddleware', + 'django.middleware.csrf.CsrfViewMiddleware', + 'django.contrib.auth.middleware.AuthenticationMiddleware', + 'django.contrib.messages.middleware.MessageMiddleware', + 'django.middleware.clickjacking.XFrameOptionsMiddleware', +] + +ROOT_URLCONF = 'vilbert_multitask.urls' + +TEMPLATES = [ + { + 'BACKEND': 'django.template.backends.django.DjangoTemplates', + 'DIRS': [os.path.join(BASE_DIR, "demo", "templates", "vilbert_multitask")], + 'APP_DIRS': True, + 'OPTIONS': { + 'context_processors': [ + 'django.template.context_processors.debug', + 'django.template.context_processors.request', + 'django.contrib.auth.context_processors.auth', + 'django.contrib.messages.context_processors.messages', + ], + }, + }, +] + +WSGI_APPLICATION = 'vilbert_multitask.wsgi.application' + + +# Database +# https://docs.djangoproject.com/en/1.11/ref/settings/#databases + +# DATABASES = { +# 'default': { +# 'ENGINE': 'django.db.backends.sqlite3', +# 'NAME': os.path.join(BASE_DIR, 'db.sqlite3'), +# } +# } + +DATABASES = { + 'default': { + 'ENGINE': 'django.db.backends.postgresql', + 'NAME': 'vilbert_multitask', + 'USER': 'vilbert', + 'PASSWORD': 'vilbert@123', + 'HOST': '127.0.0.1', + 'PORT': '5432', + } +} + + +# Password validation +# https://docs.djangoproject.com/en/1.11/ref/settings/#auth-password-validators + +AUTH_PASSWORD_VALIDATORS = [ + { + 'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator', + }, + { + 'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator', + }, + { + 'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator', + }, + { + 'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator', + }, +] + + +# Internationalization +# https://docs.djangoproject.com/en/1.11/topics/i18n/ + +LANGUAGE_CODE = 'en-us' + +TIME_ZONE = 'UTC' + +USE_I18N = True + +USE_L10N = True + +USE_TZ = True + + +# Static files (CSS, JavaScript, Images) +# https://docs.djangoproject.com/en/1.11/howto/static-files/ +STATIC_URL = '/static/' +STATIC_ROOT = os.path.join(BASE_DIR, 'static') + +MEDIA_ROOT = os.path.join(BASE_DIR, 'media') + +MEDIA_URL= "/media/" + +PIKA_HOST = 'localhost' +CHANNEL_LAYERS = { + "default": { + "BACKEND": "asgi_redis.RedisChannelLayer", + "CONFIG": { + "hosts": [("localhost", 6379)], + "prefix": u"vilbert_multitask_demo" + }, + "ROUTING": "demo.routers.channel_routing", + }, +} \ No newline at end of file diff --git a/vilbert_multitask/urls.py b/vilbert_multitask/urls.py new file mode 100644 index 0000000..f7e4092 --- /dev/null +++ b/vilbert_multitask/urls.py @@ -0,0 +1,31 @@ +"""vilbert_multitask URL Configuration + +The `urlpatterns` list routes URLs to views. For more information please see: + https://docs.djangoproject.com/en/1.11/topics/http/urls/ +Examples: +Function views + 1. Add an import: from my_app import views + 2. Add a URL to urlpatterns: url(r'^$', views.home, name='home') +Class-based views + 1. Add an import: from other_app.views import Home + 2. Add a URL to urlpatterns: url(r'^$', Home.as_view(), name='home') +Including another URLconf + 1. Import the include() function: from django.conf.urls import url, include + 2. Add a URL to urlpatterns: url(r'^blog/', include('blog.urls')) +""" +from django.conf.urls import url, include +from django.contrib import admin +from django.conf import settings +import django + +urlpatterns = [ + url(r"^admin/", admin.site.urls), + url(r"^", include("demo.urls"), name="demo"), +] + + +if settings.DEBUG: +# # static files (images, css, javascript, etc.) + urlpatterns += [ + url(r'^media/(?P.*)$', django.views.static.serve, {'document_root': settings.MEDIA_ROOT}), + ] \ No newline at end of file diff --git a/vilbert_multitask/wsgi.py b/vilbert_multitask/wsgi.py new file mode 100644 index 0000000..4697a59 --- /dev/null +++ b/vilbert_multitask/wsgi.py @@ -0,0 +1,16 @@ +""" +WSGI config for vilbert_multitask project. + +It exposes the WSGI callable as a module-level variable named ``application``. + +For more information on this file, see +https://docs.djangoproject.com/en/1.11/howto/deployment/wsgi/ +""" + +import os + +from django.core.wsgi import get_wsgi_application + +os.environ.setdefault("DJANGO_SETTINGS_MODULE", "vilbert_multitask.settings") + +application = get_wsgi_application() diff --git a/worker.py b/worker.py new file mode 100644 index 0000000..3e6cdae --- /dev/null +++ b/worker.py @@ -0,0 +1,686 @@ +from __future__ import absolute_import +import os +os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'vilbert_multitask.settings') + +import django +django.setup() + +from django.conf import settings +from demo.utils import log_to_terminal +from demo.models import QuestionAnswer, Tasks + +import demo.constants as constants +import pika +import time +import yaml +import json +import traceback +import signal +import requests +import atexit + +django.db.close_old_connections() + + +import sys +import os +import torch +import yaml +import cv2 +import argparse +import glob +import pdb +import numpy as np +import PIL +import _pickle as cPickle +import time +import traceback +import uuid + +from PIL import Image +from easydict import EasyDict as edict +from pytorch_transformers.tokenization_bert import BertTokenizer + +from vilbert.datasets import ConceptCapLoaderTrain, ConceptCapLoaderVal +from vilbert.vilbert import VILBertForVLTasks, BertConfig, BertForMultiModalPreTraining +from vilbert.task_utils import LoadDatasetEval + +import matplotlib.pyplot as plt + +from maskrcnn_benchmark.config import cfg +from maskrcnn_benchmark.layers import nms +from maskrcnn_benchmark.modeling.detector import build_detection_model +from maskrcnn_benchmark.structures.image_list import to_image_list +from maskrcnn_benchmark.utils.model_serialization import load_state_dict +from types import SimpleNamespace + + + +class FeatureExtractor: + MAX_SIZE = 1333 + MIN_SIZE = 800 + + def __init__(self): + self.args = self.get_parser() + self.detection_model = self._build_detection_model() + + def get_parser(self): + parser = SimpleNamespace(model_file= 'save/resnext_models/model_final.pth', + config_file='save/resnext_models/e2e_faster_rcnn_X-152-32x8d-FPN_1x_MLP_2048_FPN_512_train.yaml', + batch_size=1, + num_features=100, + feature_name="fc6", + confidence_threshold=0, + background=False, + partition=0) + return parser + + def _build_detection_model(self): + cfg.merge_from_file(self.args.config_file) + cfg.freeze() + + model = build_detection_model(cfg) + checkpoint = torch.load(self.args.model_file, map_location=torch.device("cpu")) + + load_state_dict(model, checkpoint.pop("model")) + + model.to("cuda") + model.eval() + return model + + def _image_transform(self, path): + img = Image.open(path) + im = np.array(img).astype(np.float32) + # IndexError: too many indices for array, grayscale images + if len(im.shape) < 3: + im = np.repeat(im[:, :, np.newaxis], 3, axis=2) + im = im[:,:,:3] + im = im[:, :, ::-1] + im -= np.array([102.9801, 115.9465, 122.7717]) + im_shape = im.shape + im_height = im_shape[0] + im_width = im_shape[1] + im_size_min = np.min(im_shape[0:2]) + im_size_max = np.max(im_shape[0:2]) + + # Scale based on minimum size + im_scale = self.MIN_SIZE / im_size_min + + # Prevent the biggest axis from being more than max_size + # If bigger, scale it down + if np.round(im_scale * im_size_max) > self.MAX_SIZE: + im_scale = self.MAX_SIZE / im_size_max + + im = cv2.resize( + im, None, None, fx=im_scale, fy=im_scale, interpolation=cv2.INTER_LINEAR + ) + img = torch.from_numpy(im).permute(2, 0, 1) + + im_info = {"width": im_width, "height": im_height} + + return img, im_scale, im_info + + def _process_feature_extraction( + self, output, im_scales, im_infos, feature_name="fc6", conf_thresh=0 + ): + batch_size = len(output[0]["proposals"]) + n_boxes_per_image = [len(boxes) for boxes in output[0]["proposals"]] + score_list = output[0]["scores"].split(n_boxes_per_image) + score_list = [torch.nn.functional.softmax(x, -1) for x in score_list] + feats = output[0][feature_name].split(n_boxes_per_image) + cur_device = score_list[0].device + + feat_list = [] + info_list = [] + + for i in range(batch_size): + dets = output[0]["proposals"][i].bbox / im_scales[i] + scores = score_list[i] + max_conf = torch.zeros((scores.shape[0])).to(cur_device) + conf_thresh_tensor = torch.full_like(max_conf, conf_thresh) + start_index = 1 + # Column 0 of the scores matrix is for the background class + if self.args.background: + start_index = 0 + for cls_ind in range(start_index, scores.shape[1]): + cls_scores = scores[:, cls_ind] + keep = nms(dets, cls_scores, 0.5) + max_conf[keep] = torch.where( + # Better than max one till now and minimally greater than conf_thresh + (cls_scores[keep] > max_conf[keep]) + & (cls_scores[keep] > conf_thresh_tensor[keep]), + cls_scores[keep], + max_conf[keep], + ) + + sorted_scores, sorted_indices = torch.sort(max_conf, descending=True) + num_boxes = (sorted_scores[: self.args.num_features] != 0).sum() + keep_boxes = sorted_indices[: self.args.num_features] + feat_list.append(feats[i][keep_boxes]) + bbox = output[0]["proposals"][i][keep_boxes].bbox / im_scales[i] + # Predict the class label using the scores + objects = torch.argmax(scores[keep_boxes][start_index:], dim=1) + cls_prob = torch.max(scores[keep_boxes][start_index:], dim=1) + + info_list.append( + { + "bbox": bbox.cpu().numpy(), + "num_boxes": num_boxes.item(), + "objects": objects.cpu().numpy(), + "image_width": im_infos[i]["width"], + "image_height": im_infos[i]["height"], + "cls_prob": scores[keep_boxes].cpu().numpy(), + } + ) + + return feat_list, info_list + + def get_detectron_features(self, image_paths): + img_tensor, im_scales, im_infos = [], [], [] + + for image_path in image_paths: + im, im_scale, im_info = self._image_transform(image_path) + img_tensor.append(im) + im_scales.append(im_scale) + im_infos.append(im_info) + + # Image dimensions should be divisible by 32, to allow convolutions + # in detector to work + current_img_list = to_image_list(img_tensor, size_divisible=32) + current_img_list = current_img_list.to("cuda") + + with torch.no_grad(): + output = self.detection_model(current_img_list) + + feat_list = self._process_feature_extraction( + output, + im_scales, + im_infos, + self.args.feature_name, + self.args.confidence_threshold, + ) + + return feat_list + + def _chunks(self, array, chunk_size): + for i in range(0, len(array), chunk_size): + yield array[i : i + chunk_size] + + def _save_feature(self, file_name, feature, info): + file_base_name = os.path.basename(file_name) + file_base_name = file_base_name.split(".")[0] + info["image_id"] = file_base_name + info["features"] = feature.cpu().numpy() + file_base_name = file_base_name + ".npy" + + np.save(os.path.join(self.args.output_folder, file_base_name), info) + + def extract_features(self, image_path): + + with torch.no_grad(): + features, infos = self.get_detectron_features(image_path) + + return features, infos + + +def tokenize_batch(batch): + return [tokenizer.convert_tokens_to_ids(sent) for sent in batch] + +def untokenize_batch(batch): + return [tokenizer.convert_ids_to_tokens(sent) for sent in batch] + +def detokenize(sent): + """ Roughly detokenizes (mainly undoes wordpiece) """ + new_sent = [] + for i, tok in enumerate(sent): + if tok.startswith("##"): + new_sent[len(new_sent) - 1] = new_sent[len(new_sent) - 1] + tok[2:] + else: + new_sent.append(tok) + return new_sent + +def printer(sent, should_detokenize=True): + if should_detokenize: + sent = detokenize(sent)[1:-1] + print(" ".join(sent)) + + +def prediction(question, features, spatials, segment_ids, input_mask, image_mask, co_attention_mask, task_tokens, task_id, infos): + + if task_id == "7": + N = len(infos) # define top N results need to return. + else: + N = 3 + + # check the number of image is correct: + if task_id in ["1", "15", "13", "11", "4", "16"]: + assert len(infos) == 1, "task require 1 image" + elif task_id in ["12"]: + assert len(infos) == 2, "task require 2 images" + elif task_id in ["7"]: + assert len(infos) > 1 and len(infos) <= 10, "task require 2-10 images" + else: + raise ValueError('task not valid.') + + + if task_id == "12": + batch_size = 1 + max_num_bbox = features.size(1) + num_options = question.size(1) + question = question.repeat(2, 1) + # question = question.view(batch_size * 2, int(question.size(1) / 2)) + input_mask = input_mask.repeat(2, 1) + # input_mask = input_mask.view(batch_size * 2, int(input_mask.size(1) / 2)) + segment_ids = segment_ids.repeat(2, 1) + # segment_ids = segment_ids.view(batch_size * 2, int(segment_ids.size(1) / 2)) + task_tokens = task_tokens.repeat(2, 1) + + if task_id == "7": + num_image = features.size(0) + max_num_bbox = features.size(1) + question = question.repeat(num_image, 1) + input_mask = input_mask.repeat(num_image, 1) + segment_ids = segment_ids.repeat(num_image, 1) + task_tokens = task_tokens.repeat(num_image, 1) + + with torch.no_grad(): + vil_prediction, vil_prediction_gqa, vil_logit, vil_binary_prediction, vil_tri_prediction, vision_prediction, vision_logit, linguisic_prediction, linguisic_logit, attn_data_list = model( + question, features, spatials, segment_ids, input_mask, image_mask, co_attention_mask, task_tokens, output_all_attention_masks=True + ) + + # logits = torch.max(vil_prediction, 1)[1].data # argmax + # pdb.set_trace() + + # Load VQA label to answers: + if task_id == "1" or task_id == "2": + prob = torch.softmax(vil_prediction.view(-1), dim=0) + prob_val, prob_idx = torch.sort(prob, 0, True) + + label2ans_path = os.path.join('save', "VQA" ,"cache", "trainval_label2ans.pkl") + vqa_label2ans = cPickle.load(open(label2ans_path, "rb")) + answer = [vqa_label2ans[prob_idx[i].item()] for i in range(N)] + confidence = [prob_val[i].item() for i in range(N)] + output = { + "top3_answer": answer, + "top3_confidence": confidence + } + return output + + # Load GQA label to answers: + if task_id == "15": + label2ans_path = os.path.join('save', "gqa" ,"cache", "trainval_label2ans.pkl") + + prob_gqa = torch.softmax(vil_prediction_gqa.view(-1), dim=0) + prob_val, prob_idx = torch.sort(prob_gqa, 0, True) + gqa_label2ans = cPickle.load(open(label2ans_path, "rb")) + + answer = [gqa_label2ans[prob_idx[i].item()] for i in range(N)] + confidence = [prob_val[i].item() for i in range(N)] + output = { + "top3_answer": answer, + "top3_confidence": confidence + } + return output + + # vil_binary_prediction NLVR2, 0: False 1: True Task 12 + if task_id == "12": + label_map = {0:"False", 1:"True"} + + prob_binary = torch.softmax(vil_binary_prediction.view(-1), dim=0) + prob_val, prob_idx = torch.sort(prob_binary, 0, True) + + answer = [label_map[prob_idx[i].item()] for i in range(2)] + confidence = [prob_val[i].item() for i in range(2)] + output = { + "top3_answer": answer, + "top3_confidence": confidence + } + return output + + # vil_entaliment: + if task_id == "13": + label_map = {0:"contradiction (false)", 1:"neutral", 2:"entailment (true)"} + + # logtis_tri = torch.max(vil_tri_prediction, 1)[1].data + prob_tri = torch.softmax(vil_tri_prediction.view(-1), dim=0) + prob_val, prob_idx = torch.sort(prob_tri, 0, True) + + answer = [label_map[prob_idx[i].item()] for i in range(3)] + confidence = [prob_val[i].item() for i in range(3)] + output = { + "top3_answer": answer, + "top3_confidence": confidence + } + return output + + # vil_logit: + # For image retrieval + if task_id == "7": + sort_val, sort_idx = torch.sort(torch.softmax(vil_logit.view(-1), dim=0), 0, True) + + idx = [sort_idx[i].item() for i in range(N)] + confidence = [sort_val[i].item() for i in range(N)] + output = { + "top3_answer": idx, + "top3_confidence": confidence + } + return output + + # grounding: + # For refer expressions - + if task_id == "11" or task_id == "4" or task_id == "16": + image_w = infos[0]['image_width'] + image_h = infos[0]['image_height'] + prob = torch.softmax(vision_logit.view(-1), dim=0) + grounding_val, grounding_idx = torch.sort(prob, 0, True) + out = [] + for i in range(N): + idx = grounding_idx[i] + val = grounding_val[i] + box = spatials[0][idx][:4].tolist() + y1 = int(box[1] * image_h) + y2 = int(box[3] * image_h) + x1 = int(box[0] * image_w) + x2 = int(box[2] * image_w) + out.append({"y1":y1, "y2":y2, "x1":x1, "x2":x2, 'confidence':val.item()*100}) + return out + +def custom_prediction(query, task, features, infos, task_id): + + # if task is Guesswhat: + if task_id in ["16"]: + tokens_list = [] + dialogs = query.split("q:")[1:] + for dialog in dialogs: + QA_pair = dialog.split("a:") + tokens_list.append("start " + QA_pair[0] + " answer " + QA_pair[1] + " stop ") + + tokens = '' + for token in tokens_list: + tokens = tokens + token + + tokens = tokenizer.encode(query) + tokens = tokenizer.add_special_tokens_single_sentence(tokens) + + segment_ids = [0] * len(tokens) + input_mask = [1] * len(tokens) + + max_length = 37 + if len(tokens) < max_length: + # Note here we pad in front of the sentence + padding = [0] * (max_length - len(tokens)) + tokens = tokens + padding + input_mask += padding + segment_ids += padding + + text = torch.from_numpy(np.array(tokens)).cuda().unsqueeze(0) + input_mask = torch.from_numpy(np.array(input_mask)).cuda().unsqueeze(0) + segment_ids = torch.from_numpy(np.array(segment_ids)).cuda().unsqueeze(0) + task = torch.from_numpy(np.array(task)).cuda().unsqueeze(0) + + num_image = len(infos) + + feature_list = [] + image_location_list = [] + image_mask_list = [] + for i in range(num_image): + image_w = infos[i]['image_width'] + image_h = infos[i]['image_height'] + feature = features[i] + num_boxes = feature.shape[0] + + g_feat = torch.sum(feature, dim=0) / num_boxes + num_boxes = num_boxes + 1 + feature = torch.cat([g_feat.view(1,-1), feature], dim=0) + boxes = infos[i]['bbox'] + image_location = np.zeros((boxes.shape[0], 5), dtype=np.float32) + image_location[:,:4] = boxes + image_location[:,4] = (image_location[:,3] - image_location[:,1]) * (image_location[:,2] - image_location[:,0]) / (float(image_w) * float(image_h)) + image_location[:,0] = image_location[:,0] / float(image_w) + image_location[:,1] = image_location[:,1] / float(image_h) + image_location[:,2] = image_location[:,2] / float(image_w) + image_location[:,3] = image_location[:,3] / float(image_h) + g_location = np.array([0,0,1,1,1]) + image_location = np.concatenate([np.expand_dims(g_location, axis=0), image_location], axis=0) + image_mask = [1] * (int(num_boxes)) + + feature_list.append(feature) + image_location_list.append(torch.tensor(image_location)) + image_mask_list.append(torch.tensor(image_mask)) + + + features = torch.stack(feature_list, dim=0).float().cuda() + spatials = torch.stack(image_location_list, dim=0).float().cuda() + image_mask = torch.stack(image_mask_list, dim=0).byte().cuda() + co_attention_mask = torch.zeros((num_image, num_boxes, max_length)).cuda() + + answer = prediction(text, features, spatials, segment_ids, input_mask, image_mask, co_attention_mask, task, task_id, infos) + return answer + +# ============================= +# ViLBERT Model Loading Part +# ============================= +def load_vilbert_model(): + global feature_extractor + global tokenizer + global model + + feature_extractor = FeatureExtractor() + + args = SimpleNamespace(from_pretrained= "save/multitask_model/pytorch_model_9.bin", + bert_model="bert-base-uncased", + config_file="config/bert_base_6layer_6conect.json", + max_seq_length=101, + train_batch_size=1, + do_lower_case=True, + predict_feature=False, + seed=42, + num_workers=0, + baseline=False, + img_weight=1, + distributed=False, + objective=1, + visual_target=0, + dynamic_attention=False, + task_specific_tokens=True, + tasks='1', + save_name='', + in_memory=False, + batch_size=1, + local_rank=-1, + split='mteval', + clean_train_sets=True + ) + + config = BertConfig.from_json_file(args.config_file) + with open('./vilbert_tasks.yml', 'r') as f: + task_cfg = edict(yaml.safe_load(f)) + + task_names = [] + for i, task_id in enumerate(args.tasks.split('-')): + task = 'TASK' + task_id + name = task_cfg[task]['name'] + task_names.append(name) + + timeStamp = args.from_pretrained.split('/')[-1] + '-' + args.save_name + config = BertConfig.from_json_file(args.config_file) + default_gpu=True + + if args.predict_feature: + config.v_target_size = 2048 + config.predict_feature = True + else: + config.v_target_size = 1601 + config.predict_feature = False + + if args.task_specific_tokens: + config.task_specific_tokens = True + + if args.dynamic_attention: + config.dynamic_attention = True + + config.visualization = True + num_labels = 3129 + + if args.baseline: + model = BaseBertForVLTasks.from_pretrained( + args.from_pretrained, config=config, num_labels=num_labels, default_gpu=default_gpu + ) + else: + model = VILBertForVLTasks.from_pretrained( + args.from_pretrained, config=config, num_labels=num_labels, default_gpu=default_gpu + ) + + model.eval() + cuda = torch.cuda.is_available() + if cuda: model = model.cuda(0) + tokenizer = BertTokenizer.from_pretrained( + args.bert_model, do_lower_case=args.do_lower_case + ) + + +def callback(ch, method, properties, body): + print("I'm callback") + start = time.time() + body = yaml.safe_load(body) # using yaml instead of json.loads since that unicodes the string in value + # body = {'socket_id': '4109f410-cee7-40d7-aa3b-ff7370f6c537', + # 'question': 'asd', + # 'image_path': ['/home/rjain321/demos/vilbert-miltitask-demo/vilbert_multitask/media/demo/b4497495-d5e9-44e6-8f79-776193ae068d.png', + # '/home/rjain321/demos/vilbert-miltitask-demo/vilbert_multitask/media/demo/6a029649-d40a-4ff4-b4b2-087d6360ef34.png' + # ], + # 'task_id': '7'} + print(" [x] Received %r" % body) + try: + task = Tasks.objects.get(unique_id=int(body["task_id"])) + question_obj = QuestionAnswer.objects.create(task=task, + input_text=body['question'], + input_images=body['image_path'], + socket_id=body['socket_id']) + print("created question answer object") + except: + print(str(traceback.print_exc())) + try: + image_path = body["image_path"] + features, infos = feature_extractor.extract_features(image_path) + query = body["question"] + socket_id = body["socket_id"] + task_id = body["task_id"] + task = [eval(task_id)] + answer = custom_prediction(query, task, features, infos, task_id) + if (task_id == "1" or task_id == "15" or task_id == "2" or task_id == "13"): + top3_answer = answer["top3_answer"] + top3_confidence = answer["top3_confidence"] + top3_list = [] + for i in range(3): + temp = {} + temp["answer"] = top3_answer[i] + temp["confidence"] = round(top3_confidence[i]*100, 2) + top3_list.append(temp) + + result = { + "task_id": task_id, + "result": top3_list + } + print("The task result is", result) + question_obj.answer_text = result + question_obj.save() + + if (task_id == "4" or task_id == "16" or task_id == "11"): + print("The answer is", answer) + image_name_with_bounding_boxes = uuid.uuid4() + + image = image_path[0].split("/") + abs_path = "" + for i in range(len(image)-3): + abs_path += image[i] + abs_path += "/" + color_list = [(0,0,255),(0,255,0),(255,0,0)] + image_name_list = [] + confidence_list = [] + for i, j in zip(answer, color_list): + image_obj = cv2.imread(image_path[0]) + image_name = uuid.uuid4() + # img = Image.open(image_path[0]) + # img = img.crop((i["x1"], i["y1"], i["x2"], i["y2"])) + # image_absolute_path = os.path.join(abs_path, "refer_expressions_task", str(image_name)+".jpg") + # img.save(image_absolute_path, "JPEG") + image_with_bounding_boxes = cv2.rectangle(image_obj, (i["x1"], i["y1"]), (i["x2"], i["y2"]), j, 4) + image_name_list.append(str(image_name)) + confidence_list.append(round(i["confidence"], 2)) + cv2.imwrite(os.path.join(abs_path, "media", "refer_expressions_task", str(image_name)+ ".jpg"), image_with_bounding_boxes) + result = { + "task_id": task_id, + "image_name_list": image_name_list, + "confidence_list": confidence_list + } + question_obj.answer_images = result + question_obj.save() + + if (task_id == "12"): + print(answer) + top3_answer = answer["top3_answer"] + top3_confidence = answer["top3_confidence"] + top3_list = [] + for i in range(2): + temp = {} + temp["answer"] = top3_answer[i] + temp["confidence"] = round(top3_confidence[i]*100, 2) + top3_list.append(temp) + result = { + "task_id": task_id, + "result": top3_list + } + question_obj.answer_text = result + question_obj.save() + + if (task_id == "7"): + top3_answer = answer["top3_answer"] + top3_confidence = answer["top3_confidence"] + image_name_list = [] + confidence_list = [] + for i in range(len(top3_answer)): + print(image_path[top3_answer[i]]) + if "demo" in image_path[0].split("/"): + image_name_list.append("demo/" + os.path.split(image_path[top3_answer[i]])[1].split(".")[0] + "." + str(image_path[0].split("/")[-1].split(".")[1])) + else: + image_name_list.append("test2014/" + os.path.split(image_path[top3_answer[i]])[1].split(".")[0] + "." + str(image_path[0].split("/")[-1].split(".")[1])) + confidence_list.append(round(top3_confidence[i]*100, 2)) + result = { + "task_id": task_id, + "image_name_list": image_name_list, + "confidence_list": confidence_list + } + print("The result is", result) + question_obj.answer_images = result + question_obj.save() + + log_to_terminal(body['socket_id'], {"terminal": json.dumps(result)}) + log_to_terminal(body['socket_id'], {"result": json.dumps(result)}) + log_to_terminal(body['socket_id'], {"terminal": "Completed Task"}) + ch.basic_ack(delivery_tag=method.delivery_tag) + print("Message Deleted") + django.db.close_old_connections() + except Exception as e: + print(traceback.print_exc()) + print(str(e)) + + end = time.time() + print("Time taken is", end - start) + + +def main(): + # Load correponding VQA model into global instance + load_vilbert_model() + connection = pika.BlockingConnection(pika.ConnectionParameters( + host='localhost', + port=5672, + socket_timeout=10000)) + channel = connection.channel() + channel.queue_declare(queue='vilbert_multitask_queue', durable=True) + print('[*] Waiting for messages. To exit press CTRL+C') + # Listen to interface + channel.basic_consume('vilbert_multitask_queue', callback) + channel.start_consuming() + +if __name__ == "__main__": + main()