Skip to content

Commit

Permalink
[FIX] auth layer
Browse files Browse the repository at this point in the history
  • Loading branch information
archetipo committed Mar 21, 2024
1 parent 4867ce2 commit 5579477
Show file tree
Hide file tree
Showing 7 changed files with 107 additions and 76 deletions.
15 changes: 7 additions & 8 deletions backend/ozon/core/OzonRawMiddleware.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,22 @@
# Copyright INRIM (https://www.inrim.eu)
# See LICENSE file for full licensing details.

from typing import Optional, Sequence, Union, Any
from fastapi.responses import RedirectResponse, JSONResponse
from typing import Optional, Union, Any

from starlette.requests import HTTPConnection, Request
from starlette.types import Message, Receive, Scope, Send
from fastapi import FastAPI
from .Ozon import Ozon
from .SessionMain import SessionMain, SessionBase
import logging
import bson

logger = logging.getLogger(__name__)



class OzonRawMiddleware:
def __init__(
self, app: FastAPI, settings, pwd_context: Optional[Any] = None
self, app: FastAPI, settings, pwd_context: Optional[Any] = None
) -> None:
self.app = app
self.pwd_context = pwd_context
Expand All @@ -26,7 +25,7 @@ def __init__(

@staticmethod
def get_request_object(
scope, receive, send
scope, receive, send
) -> Union[Request, HTTPConnection]:
# here we instantiate HTTPConnection instead of a Request object
# because only headers are needed so that's sufficient.
Expand All @@ -35,7 +34,7 @@ def get_request_object(
return Request(scope, receive, send)

async def __call__(
self, scope: Scope, receive: Receive, send: Send
self, scope: Scope, receive: Receive, send: Send
) -> None:
if scope["type"] not in ("http", "websocket"): # pragma: no cover
await self.app(scope, receive, send)
Expand All @@ -62,8 +61,8 @@ async def send_wrapper(message: Message) -> None:
f"object: {request.scope['ozon']} , params: {request.query_params}, headers{request.headers}"
)
# self.session = await self.init_request(request)

if not session or session is None:
logger.info(f'Is public {request.scope["ozon"].auth_service.is_public_endpoint()}')
if not session or session is None and not request.scope["ozon"].auth_service.is_public_endpoint():
response = request.scope["ozon"].auth_service.login_page()
await response(scope, receive, send)
else:
Expand Down
57 changes: 23 additions & 34 deletions backend/ozon/core/ServiceAuth.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,17 @@
# Copyright INRIM (https://www.inrim.eu)
# See LICENSE file for full licensing details.
import sys
import os
from os import listdir
from os.path import isfile, join
from fastapi.responses import RedirectResponse, JSONResponse
import ujson
import uuid

# from ozon.settings import get_settings
from .database.mongo_core import *
from collections import OrderedDict
from pathlib import Path
from fastapi import Request
from .SessionMain import SessionMain
from fastapi.responses import JSONResponse

from .BaseClass import PluginBase
from .ModelData import ModelData
from .BaseClass import BaseClass, PluginBase
from pydantic import ValidationError
from .SessionMain import SessionMain
# from ozon.settings import get_settings
from .database.mongo_core import *

import logging
import pymongo
import requests
import httpx
import uuid
from cryptography.fernet import Fernet

logger = logging.getLogger(__name__)

Expand All @@ -37,13 +26,13 @@ def __init_subclass__(cls, **kwargs):
class ServiceAuthBase(ServiceAuth):
@classmethod
def create(
cls,
settings=None,
public_endpoint="",
parent=None,
request=None,
pwd_context=None,
req_id="",
cls,
settings=None,
public_endpoint="",
parent=None,
request=None,
pwd_context=None,
req_id="",
):
self = ServiceAuthBase()
self.init(
Expand All @@ -52,13 +41,13 @@ def create(
return self

def init(
self,
settings=None,
public_endpoint="",
parent=None,
request=None,
pwd_context=None,
req_id="",
self,
settings=None,
public_endpoint="",
parent=None,
request=None,
pwd_context=None,
req_id="",
):
self.session = None
self.app_code = parent.app_code
Expand Down Expand Up @@ -155,7 +144,7 @@ async def check_default_token_header(self):
logger.debug(f"ws_request {apitoken}")
self.ws_request = True
self.token = apitoken
logger.info(f" Is WS {self.ws_request} with token {self.token}")
logger.debug(f" Is WS {self.ws_request} with token {self.token}")

async def check_session(self):
logger.info("check_session")
Expand Down Expand Up @@ -190,7 +179,7 @@ async def init_session(self):
self.session = await self.session_service.find_session_by_token()
if not self.session and self.is_public_endpoint:
self.session = await self.create_session_public_user()
if self.session.expire_datetime < datetime.now():
if self.session and self.session.expire_datetime < datetime.now():
self.session.active = False
await self.mdata.save_record(self.session)
self.session = None
Expand Down
2 changes: 2 additions & 0 deletions web-client/appinit.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from core.cache.cache import get_cache

from client_api import client_api
from auth_api import auth_api
from process_api import process_api
from requisition_api import requisition_api
import asyncio
Expand Down Expand Up @@ -90,6 +91,7 @@
),
name="static",
)
app.mount("/auth", auth_api)
app.mount("/client", client_api)
app.mount("/process", process_api)
app.mount("/requisition", requisition_api)
Expand Down
29 changes: 29 additions & 0 deletions web-client/auth_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Copyright INRIM (https://www.inrim.eu)
# See LICENSE file for full licensing details.
import logging
from typing import Optional, Union

import ujson
from fastapi import (
FastAPI,
Request,
Header,
)
from fastapi.responses import (
JSONResponse,
)

from core.ExportService import ExportService
from core.Gateway import Gateway
from settings import get_settings, templates

logger = logging.getLogger(__name__)

auth_api = FastAPI(
title=f"{get_settings().module_name} Client",
description=get_settings().description,
version=get_settings().version,
openapi_url="/openapi.json",
docs_url="/docs",
redoc_url="/redoc",
)
11 changes: 7 additions & 4 deletions web-client/core/ClientMiddleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,13 @@ async def __call__(
# await send(message)

await request.scope["interceptor"].before_request(request)
response = await self.call_next(request)
response = await request.scope["interceptor"].before_response(
request, response
)
if request.scope.get("security_next_call"):
response = await request.scope.get("security_next_call")(request)
else:
response = await self.call_next(request)
response = await request.scope["interceptor"].before_response(
request, response
)

await response(scope, receive, send)

Expand Down
67 changes: 38 additions & 29 deletions web-client/core/Gateway.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,20 @@
# Copyright INRIM (https://www.inrim.eu)
# See LICENSE file for full licensing details.
import copy
import sys
from typing import Optional
from fastapi import FastAPI, Request, Header, HTTPException, Depends, Form
from fastapi.responses import RedirectResponse, JSONResponse, HTMLResponse
from .ContentService import ContentService
from .main.base.base_class import BaseClass, PluginBase
from .main.base.utils_for_service import requote_uri
from starlette.status import HTTP_302_FOUND, HTTP_303_SEE_OTHER
from fastapi.concurrency import run_in_threadpool
from starlette.datastructures import QueryParams
import httpx
import logging
import ujson
import re

import httpx
import ujson
from core.cache.cache import get_cache
from fastapi import Request
from fastapi.responses import RedirectResponse, JSONResponse, HTMLResponse
from starlette.datastructures import QueryParams
from starlette.status import HTTP_303_SEE_OTHER

from .ContentService import ContentService
from .main.base.base_class import PluginBase
from .main.base.utils_for_service import requote_uri

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -54,7 +53,6 @@ def query_params(cls, url) -> QueryParams:
qstring = url.split("?")[1] if "?" in url else ""
return QueryParams(qstring)


def clean_form(self, form_data):
# logger.info(f"{form_data}")
dat = {}
Expand Down Expand Up @@ -106,6 +104,12 @@ def init_headers_and_token(self):
elif self.request.headers.get("apitoken"):
self.token = self.request.headers.get("apitoken")
self.is_api = True
else:
self.headers.update(
{
"authtoken": self.token,
}
)

# TODO move in shibboleth Gateway
if "x-remote-user" not in self.request.headers:
Expand Down Expand Up @@ -152,7 +156,7 @@ async def compute_datagrid_rows(self, key, model_name, rec_name=""):
return res

async def compute_datagrid_add_row(
self, key, num_rows, model_name, rec_name="", data={}
self, key, num_rows, model_name, rec_name="", data={}
):
logger.info("compute_datagrid_add_row")
await self.get_session()
Expand All @@ -177,7 +181,7 @@ async def before_submit(self, data, is_create=False):
return data.copy()

async def middleware_server_post_action(
self, content_service, submitted_data
self, content_service, submitted_data
) -> dict:
"""
This middleware method is triggered form Gateway.server_post_action method
Expand Down Expand Up @@ -206,17 +210,17 @@ async def middleware_server_post_action(

if "rec_name" in submitted_data:
allowed = (
self.name_allowed.match(submitted_data.get("rec_name"))
or False
self.name_allowed.match(submitted_data.get("rec_name"))
or False
)
if not allowed:
logger.error(f"name {submitted_data.get('rec_name')}")

err = {
"status": "error",
"message": f"Errore nel campo name "
f"{submitted_data.get('rec_name')} "
f"caratteri non consentiti",
f"{submitted_data.get('rec_name')} "
f"caratteri non consentiti",
"model": submitted_data.get("data_model"),
"data": {},
}
Expand Down Expand Up @@ -332,9 +336,9 @@ async def server_get_action(self, url_action="", modal=False):
url = url_action
server_response = {}
if server_response and (
server_response.get("action")
or server_response.get("content", {}).get("action")
or server_response.get("content", {}).get("reload")
server_response.get("action")
or server_response.get("content", {}).get("action")
or server_response.get("content", {}).get("reload")
):
content = server_response
if "content" in server_response:
Expand Down Expand Up @@ -425,7 +429,7 @@ async def get_list_models(self, domain={}, compute_label="title"):
return data

async def get_remote_data_select(
self, url, path_value, header_key, header_value_key
self, url, path_value, header_key, header_value_key
):
"""
name is a model name
Expand Down Expand Up @@ -466,7 +470,7 @@ async def get_resource_schema_select(self, type, select):
return data

async def complete_json_response(
self, res, orig_resp=None
self, res, orig_resp=None
) -> JSONResponse:
response = JSONResponse(res)
return self.complete_response(response)
Expand Down Expand Up @@ -514,12 +518,14 @@ async def get_remote_object(self, url, headers={}, params={}, cookies={}):

if "token" in params:
cookies = {"authtoken": params.get("token")}
if "cookie" in self.headers:
self.headers.pop("cookie")

if not cookies:
cookies = self.request.cookies.copy()

# logger.info(f" request headers {self.headers}")
logger.debug(f"get_remote_object --> {url}")
logger.debug(f"get_remote_object --> {url} cookies {cookies} self.headers {self.headers}")

async with httpx.AsyncClient(timeout=None) as client:
res = await client.get(
Expand Down Expand Up @@ -550,11 +556,14 @@ async def get_remote_object(self, url, headers={}, params={}, cookies={}):
return {}

async def get_remote_request(
self, url, headers={}, params={}, cookies={}, use_app=True
self, url, headers={}, params={}, cookies={}, use_app=True, service_url=False
):
if use_app:
headers = self.headers

if service_url:
url = f"{self.local_settings.service_url}{url}"

logger.info(f"get_remote_request --> {url}")
logger.info(f" request updated headers before {headers}")
async with httpx.AsyncClient(timeout=None) as client:
Expand All @@ -575,7 +584,7 @@ async def get_remote_request(
return {"status": "error", "msg": res.status_code}

async def post_remote_object(
self, url, data={}, headers={}, params={}, cookies={}
self, url, data={}, headers={}, params={}, cookies={}
):
logger.debug(url)
if self.local_settings.service_url not in url:
Expand Down Expand Up @@ -613,7 +622,7 @@ async def post_remote_object(
return {"status": "error", "msg": f"{url} ERROR {res.status_code}"}

async def post_remote_request(
self, url, data={}, headers={}, params={}, cookies={}, use_app=True
self, url, data={}, headers={}, params={}, cookies={}, use_app=True
):
if use_app:
headers = self.headers.copy()
Expand All @@ -638,7 +647,7 @@ async def post_remote_request(
return {}

async def delete_remote_object(
self, url, data={}, headers={}, params={}, cookies={}
self, url, data={}, headers={}, params={}, cookies={}
):
logger.debug(f"delete_remote_object --> {url}")

Expand Down
Loading

0 comments on commit 5579477

Please sign in to comment.