From a5ab61ba79e2f10288f82ea4772bd3aaddf83ce2 Mon Sep 17 00:00:00 2001 From: Sara H Date: Fri, 22 Aug 2025 16:00:27 -0500 Subject: [PATCH 01/30] Authenticate missing store refresh token --- backend/app/api/endpoints/base/user.py | 7 --- backend/app/api/endpoints/base/userpublic.py | 14 +++++- backend/app/domain/auth/authentication.py | 49 ++++++++++++++++--- backend/app/domain/schemas/auth/auth.py | 5 +- frontends/web/src/common/ApiService.js | 3 +- .../src/new_front/pages/Login/LoginPage.tsx | 7 +-- 6 files changed, 61 insertions(+), 24 deletions(-) diff --git a/backend/app/api/endpoints/base/user.py b/backend/app/api/endpoints/base/user.py index 9ae7b05a6..948a87ebb 100644 --- a/backend/app/api/endpoints/base/user.py +++ b/backend/app/api/endpoints/base/user.py @@ -4,8 +4,6 @@ from fastapi import APIRouter -from app.domain.auth.authentication import LoginService -from app.domain.schemas.auth.auth import LoginRequest from app.domain.schemas.base.user import UserInfoBadges from app.domain.services.base.user import UserService @@ -21,8 +19,3 @@ async def get_task_id_by_task_code(user_id: str): @router.get("/get_stats_by_user_id/{user_id}", response_model={}) async def get_stats_by_user_id(user_id: str): return UserService().get_stats_by_user_id(user_id) - - -@router.post("/authenticate") -async def authenticate(model: LoginRequest): - return LoginService().login(model.email, model.password) diff --git a/backend/app/api/endpoints/base/userpublic.py b/backend/app/api/endpoints/base/userpublic.py index fa78507a3..154a1efe7 100644 --- a/backend/app/api/endpoints/base/userpublic.py +++ b/backend/app/api/endpoints/base/userpublic.py @@ -2,10 +2,10 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from fastapi import APIRouter +from fastapi import APIRouter, Response from app.domain.auth.authentication import LoginService -from app.domain.schemas.auth.auth import CreateUserRequest +from app.domain.schemas.auth.auth import CreateUserRequest, LoginRequest, LoginResponse router = APIRouter() @@ -14,3 +14,13 @@ @router.post("/create_user") async def create_user(model: CreateUserRequest): return LoginService().create_user(model.email, model.password, model.username) + + +@router.post("/authenticate", response_model=LoginResponse) +async def authenticate(model: LoginRequest, response: Response): + return LoginService().login(model.email, model.password, response) + + +@router.post("/refresh") +async def refresh_token(response: Response): + return LoginService().refresh_token(response) diff --git a/backend/app/domain/auth/authentication.py b/backend/app/domain/auth/authentication.py index 261cb3c10..1ebc03c85 100644 --- a/backend/app/domain/auth/authentication.py +++ b/backend/app/domain/auth/authentication.py @@ -1,9 +1,11 @@ # Copyright (c) MLCommons and its affiliates. # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +#TODO: change to self.AUTH_JWT_SECRET_KEY once everything is migrated import os -from datetime import datetime, timedelta +import secrets +from datetime import datetime, timedelta, timezone from typing import Any, Union from jose import jwt @@ -23,11 +25,13 @@ class LoginService: def __init__(self) -> None: - self.AUTH_JWT_SECRET_KEY = os.getenv("AUTH_JWT_SECRET_KEY") + self.AUTH_JWT_SECRET_KEY = os.getenv("JWT_SECRET") + print("AUTH_JWT_SECRET_KEY", self.AUTH_JWT_SECRET_KEY) self.ACCESS_TOKEN_EXPIRE_MINUTES = os.getenv("AUTH_ACCESS_TOKEN_EXPIRE_MINUTES") self.REFRESH_TOKEN_EXPIRE_MINUTES = os.getenv( "AUTH_REFRESH_TOKEN_EXPIRE_MINUTES" ) + self.AUTH_COOKIWE_SECRET_KEY = os.getenv("AUTH_COOKIE_SECRET_KEY", "") self.AUTH_HASH_ALGORITHM = os.getenv("AUTH_HASH_ALGORITHM") self.users_service = UserService() self.task_user_permission_repository = TaskUserPermissionRepository() @@ -48,13 +52,31 @@ def create_token( expires_delta: int = None, ) -> str: if expires_delta: - expires_delta = datetime.utcnow() + expires_delta + expires_delta = datetime.now() + expires_delta else: - expires_delta = datetime.utcnow() + timedelta(minutes=int(minutes)) + expires_delta = datetime.now() + timedelta(minutes=int(minutes)) to_encode = {"exp": expires_delta, "sub": str(subject)} encoded_jwt = jwt.encode(to_encode, secret_key, algorithm) return encoded_jwt + + def set_refresh_token(self, response): + """Create a refresh token using secure random generation""" + refresh_token = secrets.token_hex() + cookie_expires = datetime.now(timezone.utc) + timedelta(days=60) + + print(f"Response type: {type(response)}") + print(f"Response has set_cookie: {hasattr(response, 'set_cookie')}") + + response.set_cookie( + key="dynabench_refresh_token", + value=refresh_token, + httponly=True, + path="/", + expires=cookie_expires, + secure=True, + ) + return refresh_token def create_access_token( self, subject: Union[str, Any], expires_delta: int = None @@ -75,18 +97,20 @@ def create_user(self, email: str, password: str, username: str) -> dict: user_id = self.users_service.create_user(email, password, username)["id"] self.badges_repository.add_badge(user_id, "WELCOME_NOOB") - def login(self, email: str, password: str) -> dict: + def login(self, email: str, password: str, response) -> dict: email_provider = email.split("@")[1] - if ["prolific", "amazonturk"] in email_provider: - self.create_user(email, password, email.split("@")[0]) user = self.users_service.get_by_email(email) + if not user and email_provider in ["prolific", "amazonturk"]: + user = self.create_user(email, password, email.split("@")[0]) if user is None: user_does_not_exist() hashed_pass = user["password"] if not self.verify_password(password, hashed_pass): password_is_incorrect() + token = self.create_access_token(user["email"]) + self.set_refresh_token(response) return { - "token": self.create_access_token(user["email"]), + "token": token, "user": user, } @@ -94,3 +118,12 @@ def is_admin_or_owner(self, user_id: int, task_id: int): return self.task_user_permission_repository.is_task_owner( user_id, task_id ) or self.users_service.get_is_admin(user_id) + + def refresh_token(self, response): + """Refresh the JWT token and set a new refresh token cookie""" + new_token = self.create_access_token("refresh") + refresh_token = self.set_refresh_token(response) + return { + "token": new_token, + "refresh_token": refresh_token, + } diff --git a/backend/app/domain/schemas/auth/auth.py b/backend/app/domain/schemas/auth/auth.py index e9b12ea29..5a0c90977 100644 --- a/backend/app/domain/schemas/auth/auth.py +++ b/backend/app/domain/schemas/auth/auth.py @@ -12,7 +12,6 @@ class CreateUserRequest(BaseModel): class CreateUserResponse(BaseModel): email: EmailStr - password: str username: str id: int @@ -23,8 +22,8 @@ class LoginRequest(BaseModel): class LoginResponse(BaseModel): - access_token: str - token_type: str + token: str + user: CreateUserResponse class TokenPayload(BaseModel): diff --git a/frontends/web/src/common/ApiService.js b/frontends/web/src/common/ApiService.js index c5a96adc3..182c9e39e 100644 --- a/frontends/web/src/common/ApiService.js +++ b/frontends/web/src/common/ApiService.js @@ -25,6 +25,7 @@ export default class ApiService { } else { this.domain = domain || "https://www.dynabench.org:8080"; } + this.alternateDomain = process.env.REACT_APP_API_HOST_2 || "https://www.dynabench.org:8080"; this.fetch = this.fetch.bind(this); this.setToken = this.setToken.bind(this); this.getToken = this.getToken.bind(this); @@ -961,7 +962,7 @@ export default class ApiService { } refreshToken() { - return this.doFetch(`${this.domain}/authenticate/refresh`, {}, true).then( + return this.doFetch(`${this.alternateDomain}/authenticate/refresh`, {}, true).then( (result) => { this.setToken(result.token); } diff --git a/frontends/web/src/new_front/pages/Login/LoginPage.tsx b/frontends/web/src/new_front/pages/Login/LoginPage.tsx index 39039f297..ece6bd5a7 100644 --- a/frontends/web/src/new_front/pages/Login/LoginPage.tsx +++ b/frontends/web/src/new_front/pages/Login/LoginPage.tsx @@ -43,7 +43,7 @@ const LoginPage: FC = () => { const handleLogin = () => { axios - .post(`${process.env.REACT_APP_API_HOST}/authenticate`, { + .post(`${process.env.REACT_APP_API_HOST_2}/user_public/authenticate`, { email: email, password: password, }) @@ -52,13 +52,14 @@ const LoginPage: FC = () => { updateState({ user: response.data.user, }); + console.log("Login successful", response); if (taskCode && srcURL) { history.push( - `/tasks/${taskCode}${srcURL}?assignmentId=${assignmentId}&treatmentId=${treatmentId}` + `/tasks/${taskCode}${srcURL}?assignmentId=${assignmentId}&treatmentId=${treatmentId}`, ); } else if (taskCode) { history.push( - `/tasks/${taskCode}/create?assignmentId=${assignmentId}&treatmentId=${treatmentId}` + `/tasks/${taskCode}/create?assignmentId=${assignmentId}&treatmentId=${treatmentId}`, ); } else if (originalPath === "/") { history.push("/account"); From 9dbc60f0ef8451866a678b5fac944b92a64c0017 Mon Sep 17 00:00:00 2001 From: Sara H Date: Fri, 22 Aug 2025 16:04:55 -0500 Subject: [PATCH 02/30] Update_apiService --- frontends/web/src/common/ApiService.js | 83 +++++++++++++------------- 1 file changed, 43 insertions(+), 40 deletions(-) diff --git a/frontends/web/src/common/ApiService.js b/frontends/web/src/common/ApiService.js index 182c9e39e..cd7ac6e04 100644 --- a/frontends/web/src/common/ApiService.js +++ b/frontends/web/src/common/ApiService.js @@ -25,7 +25,8 @@ export default class ApiService { } else { this.domain = domain || "https://www.dynabench.org:8080"; } - this.alternateDomain = process.env.REACT_APP_API_HOST_2 || "https://www.dynabench.org:8080"; + this.alternateDomain = + process.env.REACT_APP_API_HOST_2 || "https://www.dynabench.org:8080"; this.fetch = this.fetch.bind(this); this.setToken = this.setToken.bind(this); this.getToken = this.getToken.bind(this); @@ -99,7 +100,7 @@ export default class ApiService { method: "PUT", body: JSON.stringify(data), }, - includeCredentials + includeCredentials, ); } @@ -220,7 +221,7 @@ export default class ApiService { sort, sortDirection, metricWeights, - datasetWeights + datasetWeights, ) { const pageQuery = `limit=${limit || 10}&offset=${offset || 0}`; const sortQuery = @@ -234,7 +235,7 @@ export default class ApiService { const datasetWeightsQuery = datasetWeights ? `&ordered_scoring_dataset_weights=${encodeURIComponent( - datasetWeights.join("|") + datasetWeights.join("|"), )}` : ""; @@ -297,7 +298,7 @@ export default class ApiService { { method: "GET", }, - includeCredentials + includeCredentials, ); } @@ -314,12 +315,12 @@ export default class ApiService { `${ this.domain }/contexts/${tid}/${rid}/${method}?tags=${encodeURIComponent( - tags.join("|") + tags.join("|"), )}`, { method: "GET", }, - includeCredentials + includeCredentials, ); } @@ -328,7 +329,7 @@ export default class ApiService { rid, tags = [], context_tags = [], - annotator_id = null + annotator_id = null, ) { const includeCredentials = this.mode !== "mturk"; @@ -339,12 +340,12 @@ export default class ApiService { : ""; return this.doFetch( `${this.domain}/examples/${tid}/${rid}?tags=${encodeURIComponent( - tags.join("|") + tags.join("|"), )}${context_tags_query}${annotator_query}`, { method: "GET", }, - includeCredentials + includeCredentials, ); } @@ -355,17 +356,17 @@ export default class ApiService { maxNumFlags, minNumDisagreements, maxNumDisagreements, - tags = [] + tags = [], ) { return this.fetch( `${ this.domain }/examples/${tid}/${rid}/filtered/${minNumFlags}/${maxNumFlags}/${minNumDisagreements}/${maxNumDisagreements}?tags=${encodeURIComponent( - tags.join("|") + tags.join("|"), )}`, { method: "GET", - } + }, ); } @@ -385,7 +386,7 @@ export default class ApiService { `${this.domain}/notifications?limit=${limit || 10}&offset=${offset || 0}`, { method: "GET", - } + }, ); } @@ -396,7 +397,7 @@ export default class ApiService { }`, { method: "GET", - } + }, ); } @@ -407,7 +408,7 @@ export default class ApiService { }`, { method: "GET", - } + }, ); } @@ -416,7 +417,7 @@ export default class ApiService { `${this.domain}/users/${userId}/forks?limit=${limit}&offset=${offset}`, { method: "GET", - } + }, ); } @@ -425,7 +426,7 @@ export default class ApiService { `${this.domain}/users/${userId}/snapshots?limit=${limit}&offset=${offset}`, { method: "GET", - } + }, ); } @@ -497,7 +498,7 @@ export default class ApiService { { method: "GET", }, - includeCredentials + includeCredentials, ); } @@ -522,7 +523,7 @@ export default class ApiService { method: "PUT", body: JSON.stringify(obj), }, - includeCredentials + includeCredentials, ); } @@ -542,7 +543,7 @@ export default class ApiService { method: "PUT", body: JSON.stringify(obj), }, - includeCredentials + includeCredentials, ); } @@ -564,7 +565,7 @@ export default class ApiService { method: "POST", body: JSON.stringify(data), }, - false + false, ); } @@ -585,7 +586,7 @@ export default class ApiService { `${this.domain}/task_proposals/all/${page}/${pageLimit}`, { method: "GET", - } + }, ); } @@ -594,7 +595,7 @@ export default class ApiService { `${this.domain}/task_proposals/user/${page}/${pageLimit}`, { method: "GET", - } + }, ); } @@ -660,7 +661,7 @@ export default class ApiService { { method: "PUT", body: JSON.stringify(data), - } + }, ); } @@ -745,7 +746,7 @@ export default class ApiService { headers: { Authorization: token ? "Bearer " + token : "None", }, - } + }, ); } @@ -763,7 +764,7 @@ export default class ApiService { headers: { Authorization: token ? "Bearer " + token : "None", }, - } + }, ); } @@ -778,7 +779,7 @@ export default class ApiService { metadata, modelWrong, tag = null, - modelEndpointName = null + modelEndpointName = null, ) { const includeCredentials = this.mode !== "mturk"; return this.doFetch( @@ -799,7 +800,7 @@ export default class ApiService { model_endpoint_name: modelEndpointName, }), }, - includeCredentials + includeCredentials, ); } @@ -819,7 +820,7 @@ export default class ApiService { `${this.domain}/tasks/${tid}/leaderboard_configuration/${name}`, { method: "GET", - } + }, ); } @@ -832,7 +833,7 @@ export default class ApiService { orderedDatasetWeights, totalCount, description, - name + name, ) { return this.fetch(`${this.domain}/tasks/${tid}/leaderboard_snapshot`, { method: "PUT", @@ -854,7 +855,7 @@ export default class ApiService { `${this.domain}/tasks/${tid}/leaderboard_snapshot/${name}`, { method: "GET", - } + }, ); } @@ -863,7 +864,7 @@ export default class ApiService { `${this.domain}/tasks/${tid}/disambiguate_forks_and_snapshots/${name}`, { method: "GET", - } + }, ); } @@ -887,7 +888,7 @@ export default class ApiService { localStorage.removeItem("id_token"); //window.location.href = '/login'; return false; - } + }, ); } } @@ -962,11 +963,13 @@ export default class ApiService { } refreshToken() { - return this.doFetch(`${this.alternateDomain}/authenticate/refresh`, {}, true).then( - (result) => { - this.setToken(result.token); - } - ); + return this.doFetch( + `${this.alternateDomain}/authenticate/refresh`, + {}, + true, + ).then((result) => { + this.setToken(result.token); + }); } doFetch(url, options, includeCredentials = false) { @@ -1009,7 +1012,7 @@ export default class ApiService { localStorage.removeItem("id_token"); //window.location.href = '/login'; throw error; - } + }, ); } return this.doFetch(url, options, {}, true); From d9f201f6c9ec388ce2e366a44c4583966fdff1a8 Mon Sep 17 00:00:00 2001 From: Sara H Date: Mon, 3 Nov 2025 16:38:41 -0500 Subject: [PATCH 03/30] Include all the refresh token logic on this backend and fix up things for them to work with this --- backend/app/api/endpoints/auth.py | 38 +++- backend/app/api/endpoints/base/model.py | 14 +- backend/app/api/endpoints/base/user.py | 10 +- backend/app/api/endpoints/base/userpublic.py | 26 --- .../{authorization.py => authentication.py} | 9 +- backend/app/domain/auth/authentication.py | 215 +++++++++++++++--- backend/app/domain/helpers/exceptions.py | 8 + .../infrastructure/repositories/abstract.py | 6 + .../repositories/refreshtoken.py | 14 ++ backend/app/main.py | 24 +- 10 files changed, 284 insertions(+), 80 deletions(-) delete mode 100644 backend/app/api/endpoints/base/userpublic.py rename backend/app/api/middleware/{authorization.py => authentication.py} (89%) diff --git a/backend/app/api/endpoints/auth.py b/backend/app/api/endpoints/auth.py index ee96950a1..f0e3f7cfa 100644 --- a/backend/app/api/endpoints/auth.py +++ b/backend/app/api/endpoints/auth.py @@ -2,14 +2,14 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from fastapi import APIRouter, Depends -from fastapi.security import OAuth2PasswordRequestForm +from fastapi import APIRouter, Depends, Header, Request, Response +from app.api.middleware.authentication import validate_access_token from app.domain.auth.authentication import LoginService from app.domain.schemas.auth.auth import ( CreateUserRequest, - CreateUserResponse, IsAdminOrOwnerRequest, + LoginRequest, LoginResponse, ) @@ -17,16 +17,34 @@ router = APIRouter() -@router.post("/login", response_model=LoginResponse) -async def login(model: OAuth2PasswordRequestForm = Depends()): - return LoginService().login(model.username, model.password) - - -@router.post("/create_user", response_model=CreateUserResponse) +@router.post("/create_user") async def create_user(model: CreateUserRequest): return LoginService().create_user(model.email, model.password, model.username) @router.post("/is_admin_or_owner", response_model=bool) -async def is_admin_or_owner(model: IsAdminOrOwnerRequest): +async def is_admin_or_owner( + model: IsAdminOrOwnerRequest, token_payload=Depends(validate_access_token) +): return LoginService().is_admin_or_owner(model.user_id, model.task_id) + + +@router.get("/refresh") +async def refresh_token( + request: Request, + response: Response, + authorization: str = Header(..., description="Bearer token required"), +): + return LoginService().refresh_token(request, response, authorization) + + +@router.post("/login", response_model=LoginResponse) +async def login(model: LoginRequest, response: Response): + return LoginService().login(model.email, model.password, response) + + +@router.post("/logout") +async def logout( + request: Request, response: Response, token_payload=Depends(validate_access_token) +): + return LoginService().logout(request, response) diff --git a/backend/app/api/endpoints/base/model.py b/backend/app/api/endpoints/base/model.py index 9ba861a9c..f559ca51c 100644 --- a/backend/app/api/endpoints/base/model.py +++ b/backend/app/api/endpoints/base/model.py @@ -1,7 +1,15 @@ # Copyright (c) MLCommons and its affiliates. # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from fastapi import APIRouter, BackgroundTasks, Depends, File, Response, UploadFile +from fastapi import ( + APIRouter, + BackgroundTasks, + Depends, + File, + Request, + Response, + UploadFile, +) from fastapi.responses import FileResponse from app.domain.schemas.base.model import ( @@ -231,7 +239,9 @@ def update_model_status(model_id: int): @router.get("/get_models_by_user_id/{user_id}") -def get_models_by_user_id(user_id: int): +def get_models_by_user_id(user_id: int, request: Request): + if user_id != request.state.user: + raise PermissionError("Unauthorized access to model data.") return ModelService().get_models_by_user_id(user_id) diff --git a/backend/app/api/endpoints/base/user.py b/backend/app/api/endpoints/base/user.py index 948a87ebb..c9b624455 100644 --- a/backend/app/api/endpoints/base/user.py +++ b/backend/app/api/endpoints/base/user.py @@ -2,7 +2,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from fastapi import APIRouter +from fastapi import APIRouter, Request from app.domain.schemas.base.user import UserInfoBadges from app.domain.services.base.user import UserService @@ -12,10 +12,14 @@ @router.get("/get_user_with_badges/{user_id}", response_model=UserInfoBadges) -async def get_task_id_by_task_code(user_id: str): +async def get_task_id_by_task_code(user_id: int, request: Request): + if user_id != request.state.user: + raise PermissionError("Unauthorized access to user data.") return UserService().get_user_with_badges(user_id) @router.get("/get_stats_by_user_id/{user_id}", response_model={}) -async def get_stats_by_user_id(user_id: str): +async def get_stats_by_user_id(user_id: int, request: Request): + if user_id != request.state.user: + raise PermissionError("Unauthorized access to user data.") return UserService().get_stats_by_user_id(user_id) diff --git a/backend/app/api/endpoints/base/userpublic.py b/backend/app/api/endpoints/base/userpublic.py deleted file mode 100644 index 154a1efe7..000000000 --- a/backend/app/api/endpoints/base/userpublic.py +++ /dev/null @@ -1,26 +0,0 @@ -# Copyright (c) MLCommons and its affiliates. -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -from fastapi import APIRouter, Response - -from app.domain.auth.authentication import LoginService -from app.domain.schemas.auth.auth import CreateUserRequest, LoginRequest, LoginResponse - - -router = APIRouter() - - -@router.post("/create_user") -async def create_user(model: CreateUserRequest): - return LoginService().create_user(model.email, model.password, model.username) - - -@router.post("/authenticate", response_model=LoginResponse) -async def authenticate(model: LoginRequest, response: Response): - return LoginService().login(model.email, model.password, response) - - -@router.post("/refresh") -async def refresh_token(response: Response): - return LoginService().refresh_token(response) diff --git a/backend/app/api/middleware/authorization.py b/backend/app/api/middleware/authentication.py similarity index 89% rename from backend/app/api/middleware/authorization.py rename to backend/app/api/middleware/authentication.py index 396ee08b2..1b6eb8699 100644 --- a/backend/app/api/middleware/authorization.py +++ b/backend/app/api/middleware/authentication.py @@ -22,6 +22,7 @@ async def verify_token(token: str): token, os.getenv("JWT_SECRET"), algorithms=[os.getenv("AUTH_HASH_ALGORITHM")], + options={"verify_exp": True}, ) return decoded_token except Exception as e: @@ -45,9 +46,11 @@ async def validate_access_token( # The OAuth2PasswordBearer already extracts and validates the Bearer token format # So we don't need to manually extract it from headers decoded: AccessTokenPayload = await verify_token(token) + refresh_token = request.cookies.get("dynabench_refresh_token", None) + if not refresh_token: + raise credentials_exception() # While we migrate the login into Backend we are using id # That is what the API sends in the token. - # email = decoded.get("email", None) id = decoded.get("id", None) if not id: raise credentials_exception() @@ -55,5 +58,5 @@ async def validate_access_token( # Once we have mirated the login into Backend we will use email. # user = repository.get_by_email(email) if user is None or not user: - raise credentials_exception() - request.state.user = user + credentials_exception() + request.state.user = user["id"] diff --git a/backend/app/domain/auth/authentication.py b/backend/app/domain/auth/authentication.py index 1ebc03c85..9298adb4a 100644 --- a/backend/app/domain/auth/authentication.py +++ b/backend/app/domain/auth/authentication.py @@ -1,41 +1,48 @@ # Copyright (c) MLCommons and its affiliates. # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -#TODO: change to self.AUTH_JWT_SECRET_KEY once everything is migrated import os import secrets from datetime import datetime, timedelta, timezone from typing import Any, Union +from fastapi import HTTPException, status from jose import jwt from werkzeug.security import check_password_hash, generate_password_hash from app.domain.helpers.exceptions import ( + credentials_exception, password_is_incorrect, + refresh_token_expired, user_does_not_exist, user_with_email_already_exists, ) from app.domain.services.base.user import UserService from app.infrastructure.repositories.badge import BadgeRepository +from app.infrastructure.repositories.refreshtoken import RefreshTokenRepository from app.infrastructure.repositories.taskuserpermission import ( TaskUserPermissionRepository, ) +from app.infrastructure.repositories.user import UserRepository class LoginService: def __init__(self) -> None: self.AUTH_JWT_SECRET_KEY = os.getenv("JWT_SECRET") - print("AUTH_JWT_SECRET_KEY", self.AUTH_JWT_SECRET_KEY) - self.ACCESS_TOKEN_EXPIRE_MINUTES = os.getenv("AUTH_ACCESS_TOKEN_EXPIRE_MINUTES") - self.REFRESH_TOKEN_EXPIRE_MINUTES = os.getenv( - "AUTH_REFRESH_TOKEN_EXPIRE_MINUTES" + self.ACCESS_TOKEN_EXPIRE_MINUTES = int( + os.getenv("AUTH_ACCESS_TOKEN_EXPIRE_MINUTES") ) - self.AUTH_COOKIWE_SECRET_KEY = os.getenv("AUTH_COOKIE_SECRET_KEY", "") - self.AUTH_HASH_ALGORITHM = os.getenv("AUTH_HASH_ALGORITHM") + self.REFRESH_TOKEN_EXPIRE_DAYS = int( + os.getenv("AUTH_REFRESH_TOKEN_EXPIRE_DAYS") + ) + self.AUTH_COOKIE_SECRET_KEY = os.getenv("AUTH_COOKIE_SECRET_KEY", "") + self.AUTH_HASH_ALGORITHM = os.getenv("AUTH_HASH_ALGORITHM", "HS256") self.users_service = UserService() self.task_user_permission_repository = TaskUserPermissionRepository() self.badges_repository = BadgeRepository() + self.refresh_token_repository = RefreshTokenRepository() + self.users_repository = UserRepository() def get_hashed_password(self, password: str) -> str: return generate_password_hash(password) @@ -52,21 +59,30 @@ def create_token( expires_delta: int = None, ) -> str: if expires_delta: - expires_delta = datetime.now() + expires_delta + expires_delta = datetime.now(timezone.utc) + expires_delta else: - expires_delta = datetime.now() + timedelta(minutes=int(minutes)) + expires_delta = datetime.now(timezone.utc) + timedelta(minutes=int(minutes)) - to_encode = {"exp": expires_delta, "sub": str(subject)} + to_encode = {"exp": expires_delta, **subject} encoded_jwt = jwt.encode(to_encode, secret_key, algorithm) return encoded_jwt - - def set_refresh_token(self, response): + + def set_refresh_token(self, response, user_id: int) -> str: """Create a refresh token using secure random generation""" - refresh_token = secrets.token_hex() - cookie_expires = datetime.now(timezone.utc) + timedelta(days=60) + refresh_token = secrets.token_hex(32) + cookie_expires = datetime.now(timezone.utc) + timedelta( + days=self.REFRESH_TOKEN_EXPIRE_DAYS + ) - print(f"Response type: {type(response)}") - print(f"Response has set_cookie: {hasattr(response, 'set_cookie')}") + self.cleanup_old_refresh_tokens(user_id) + + self.refresh_token_repository.add( + { + "token": refresh_token, + "uid": user_id, + "generated_datetime": datetime.now(timezone.utc), + } + ) response.set_cookie( key="dynabench_refresh_token", @@ -74,7 +90,10 @@ def set_refresh_token(self, response): httponly=True, path="/", expires=cookie_expires, - secure=True, + # For localhost testing set secure to False + secure=False, + # For Localhost testing set samesite to None, else lax + samesite="lax", ) return refresh_token @@ -107,23 +126,165 @@ def login(self, email: str, password: str, response) -> dict: hashed_pass = user["password"] if not self.verify_password(password, hashed_pass): password_is_incorrect() - token = self.create_access_token(user["email"]) - self.set_refresh_token(response) + token = self.create_access_token({"id": user["id"]}) + self.set_refresh_token(response, user["id"]) return { "token": token, "user": user, } + def logout(self, request, response) -> dict: + """ + Logout the user by deleting the refresh token from cookies and database + + Args: + request: Request object to extract cookies + response: Response object to delete refresh token cookie + """ + current_refresh_token = self.get_refresh_token_from_cookie(request) + db_token = self.refresh_token_repository.get_by_token(current_refresh_token) + + if db_token: + uid = db_token.get("uid", None) + else: + credentials_exception() + + user = self.users_repository.get_by_id(uid) + if not user or user["id"] != request.state.user: + raise credentials_exception() + if current_refresh_token and db_token: + self.refresh_token_repository.delete(db_token["id"]) + response.delete_cookie("dynabench_refresh_token") + return {"message": "Logged out successfully"} + else: + refresh_token_expired() + def is_admin_or_owner(self, user_id: int, task_id: int): return self.task_user_permission_repository.is_task_owner( user_id, task_id ) or self.users_service.get_is_admin(user_id) - def refresh_token(self, response): - """Refresh the JWT token and set a new refresh token cookie""" - new_token = self.create_access_token("refresh") - refresh_token = self.set_refresh_token(response) - return { - "token": new_token, - "refresh_token": refresh_token, - } + def refresh_token(self, request, response, authorization_header: str) -> dict: + """ + Refresh the JWT token using the refresh token from cookie + + Args: + request: Request object to extract cookies + response: Response object to set new refresh token cookie + authorization_header: Optional current access token to get user info + + Returns: + Dict with new access token and user info + """ + try: + current_refresh_token = None + user_id = None + + # Step 1: Validate that authorization header is provided + if not authorization_header: + raise Exception("Invalid or expired bearer token") + + current_refresh_token = self.get_refresh_token_from_cookie(request) + + if not current_refresh_token: + refresh_token_expired() + + # Step 2: Find user by refresh token in database and validate if user owns it + db_token = self.refresh_token_repository.get_by_token(current_refresh_token) + if not db_token: + refresh_token_expired() + + user_id = db_token["uid"] + # Step 3: Validate user exists + user = self.users_repository.get_by_id(user_id) + if not user: + raise Exception("User not found") + + # Step 4: Validate refresh token is not expired and the user from the token matches the user + if not self.validate_refresh_token_in_db( + db_token, authorization_header, user + ): + refresh_token_expired() + + # Step 5: Create new access token + new_access_token = self.create_access_token({"id": user["id"]}) + + # Step 6: Create new refresh token and store it + self.set_refresh_token(response, user_id) + + return { + "token": new_access_token, + "message": "Token refreshed successfully", + } + + except Exception as e: + # Clean up any invalid tokens + if "current_refresh_token" in locals(): + try: + db_token = self.refresh_token_repository.get_by_token( + current_refresh_token + ) + if db_token: + self.refresh_token_repository.delete(db_token["id"]) + except Exception: + pass + + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=f"Token refresh failed: {str(e)}", + headers={"WWW-Authenticate": "Bearer"}, + ) + + def get_refresh_token_from_cookie(self, request) -> str: + """Extract refresh token from HTTP-only cookie""" + return request.cookies.get("dynabench_refresh_token", None) + + def validate_refresh_token_in_db( + self, refresh_token: dict, authorization_header: str, user: dict + ) -> bool: + """Validate if refresh token exists in database and is not expired also if it is from the user""" + try: + # Check if token is expired (assuming 60 days expiration) + # Maybe verify the age from the token itself? + deadline = refresh_token.get("generated_datetime", None) + if not deadline: + return False + + if deadline.tzinfo is None: + deadline = deadline.replace(tzinfo=timezone.utc) + + token_age = datetime.now(timezone.utc) - deadline + + if token_age.days > self.REFRESH_TOKEN_EXPIRE_DAYS: + # Clean up expired token + self.refresh_token_repository.delete(refresh_token["id"]) + return False + + access_token = authorization_header[7:] + payload = jwt.decode( + access_token, + self.AUTH_JWT_SECRET_KEY, + algorithms=[self.AUTH_HASH_ALGORITHM], + options={"verify_exp": False}, # Allow expired tokens for refresh + ) + + payload_user_id = payload.get("id", None) + user_id = user.get("id", None) + + if payload_user_id != user_id or payload_user_id is None: + return False + + return True + except (jwt.JWTError, ValueError, KeyError, Exception) as e: + print(f"Token validation error: {e}") + return False + + def cleanup_old_refresh_tokens(self, user_id: int): + """Remove old refresh tokens for the user (keep only the latest)""" + try: + old_tokens = self.refresh_token_repository.get_all_by_user_id(user_id) + for token in old_tokens: + self.refresh_token_repository.delete(token["id"]) + except Exception as e: + print(f"Error cleaning up old refresh tokens: {e}") + pass diff --git a/backend/app/domain/helpers/exceptions.py b/backend/app/domain/helpers/exceptions.py index 31da51598..f2c5d5c0b 100644 --- a/backend/app/domain/helpers/exceptions.py +++ b/backend/app/domain/helpers/exceptions.py @@ -38,3 +38,11 @@ def bad_token() -> HTTPException: detail="Invalid token", headers={"WWW-Authenticate": "Bearer"}, ) + + +def refresh_token_expired() -> HTTPException: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Refresh token has expired", + headers={"WWW-Authenticate": "Bearer"}, + ) diff --git a/backend/app/infrastructure/repositories/abstract.py b/backend/app/infrastructure/repositories/abstract.py index 3f41ca501..6bca4dced 100644 --- a/backend/app/infrastructure/repositories/abstract.py +++ b/backend/app/infrastructure/repositories/abstract.py @@ -35,3 +35,9 @@ def get_by_id(self, id: int) -> dict: instance = self.session.query(self.model).get(id) instance = self.instance_converter.instance_to_dict(instance) return instance + + def delete(self, id: int) -> None: + instance = self.session.query(self.model).get(id) + with self.session as session: + session.delete(instance) + session.commit() diff --git a/backend/app/infrastructure/repositories/refreshtoken.py b/backend/app/infrastructure/repositories/refreshtoken.py index 8ceb80f9b..bf64642e5 100644 --- a/backend/app/infrastructure/repositories/refreshtoken.py +++ b/backend/app/infrastructure/repositories/refreshtoken.py @@ -13,3 +13,17 @@ class RefreshTokenRepository(AbstractRepository): def __init__(self) -> None: super().__init__(RefreshToken) + + def get_by_token(self, token: str) -> RefreshToken: + instance = ( + self.session.query(RefreshToken).filter(RefreshToken.token == token).first() + ) + return self.instance_converter.instance_to_dict(instance) + + def get_all_by_user_id(self, user_id: int): + instances = ( + self.session.query(RefreshToken).filter(RefreshToken.uid == user_id).all() + ) + return [ + self.instance_converter.instance_to_dict(instance) for instance in instances + ] diff --git a/backend/app/main.py b/backend/app/main.py index 6ce735695..acec997ee 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -24,10 +24,9 @@ task, task_proposals, user, - userpublic, ) from app.api.endpoints.builder_and_evaluation import evaluation -from app.api.middleware.authorization import validate_access_token +from app.api.middleware.authentication import validate_access_token load_dotenv() @@ -38,14 +37,26 @@ "http://localhost:3000", "https://www.dynabench.org", "https://front-dev.dynabench.org", + # "postman://app", include this only when testing with postman ] app.add_middleware( CORSMiddleware, - allow_origins=["*"], + allow_origins=origins, allow_credentials=True, allow_methods=["*"], - allow_headers=["*"], + allow_headers=[ + "Accept", + "Accept-Language", + "Content-Language", + "Content-Type", + "Authorization", + "Cookie", + "Set-Cookie", + "Access-Control-Allow-Headers", + "Access-Control-Allow-Credentials", + "Access-Control-Allow-Origin", + ], ) @@ -92,11 +103,6 @@ def read_root(): tags=["user"], dependencies=[Depends(validate_access_token)], ) -app.include_router( - userpublic.router, - prefix="/user_public", - tags=["user_public"], -) app.include_router( task_proposals.router, prefix="/task_proposals", tags=["task_proposals"] ) From 03a6467426ad18758600ec42e7ef6becd7075fee Mon Sep 17 00:00:00 2001 From: Sara H Date: Mon, 3 Nov 2025 17:38:31 -0500 Subject: [PATCH 04/30] include changes in the frontEnd --- frontends/web/src/common/ApiService.js | 14 ++-- frontends/web/src/containers/App.js | 2 + .../src/new_front/pages/Login/LoginPage.tsx | 69 +++++++++++++------ .../src/new_front/pages/Login/Register.tsx | 2 +- 4 files changed, 56 insertions(+), 31 deletions(-) diff --git a/frontends/web/src/common/ApiService.js b/frontends/web/src/common/ApiService.js index cd7ac6e04..e1addc46b 100644 --- a/frontends/web/src/common/ApiService.js +++ b/frontends/web/src/common/ApiService.js @@ -924,7 +924,7 @@ export default class ApiService { logout() { try { - this.fetch(`${this.domain}/authenticate/logout`, { + this.fetch(`${this.alternateDomain}/auth/logout`, { method: "POST", }); } catch (e) { @@ -963,13 +963,11 @@ export default class ApiService { } refreshToken() { - return this.doFetch( - `${this.alternateDomain}/authenticate/refresh`, - {}, - true, - ).then((result) => { - this.setToken(result.token); - }); + return this.doFetch(`${this.alternateDomain}/auth/refresh`, {}, true).then( + (result) => { + this.setToken(result.token); + }, + ); } doFetch(url, options, includeCredentials = false) { diff --git a/frontends/web/src/containers/App.js b/frontends/web/src/containers/App.js index 7c8958644..e01bef207 100644 --- a/frontends/web/src/containers/App.js +++ b/frontends/web/src/containers/App.js @@ -178,6 +178,8 @@ class App extends React.Component { "Content-Type": "application/json", Authorization: `Bearer ${this.api.getToken()}`, }, + credentials: "include", + mode: "cors", }} > diff --git a/frontends/web/src/new_front/pages/Login/LoginPage.tsx b/frontends/web/src/new_front/pages/Login/LoginPage.tsx index ece6bd5a7..c1edbc5a9 100644 --- a/frontends/web/src/new_front/pages/Login/LoginPage.tsx +++ b/frontends/web/src/new_front/pages/Login/LoginPage.tsx @@ -1,6 +1,6 @@ import React, { FC, useContext, useEffect, useState } from "react"; import { Link, useHistory, useLocation } from "react-router-dom"; -import { Button } from "react-bootstrap"; +import { Button, Spinner } from "react-bootstrap"; import axios from "axios"; import Swal from "sweetalert2"; import { useTranslation } from "react-i18next"; @@ -11,11 +11,13 @@ import { ReactComponent as Login } from "new_front/assets/login.svg"; const LoginPage: FC = () => { const [email, setEmail] = useState(""); const [password, setPassword] = useState(""); - const { updateState } = useContext(UserContext); + const { updateState: updateUserContext } = useContext(UserContext); const originalPath = localStorage.getItem("originalPath"); const history = useHistory(); const location = useLocation(); const { t } = useTranslation(); + const [loading, setLoading] = useState(false); + const [localSetted, setLocalSetted] = useState(false); const queryParams = new URLSearchParams(location.search); const assignmentId = @@ -42,37 +44,35 @@ const LoginPage: FC = () => { }, []); const handleLogin = () => { + setLoading(true); axios - .post(`${process.env.REACT_APP_API_HOST_2}/user_public/authenticate`, { - email: email, - password: password, - }) + .post( + `${process.env.REACT_APP_API_HOST_2}/auth/login`, + { + email: email, + password: password, + }, + { + withCredentials: true, + } + ) .then((response) => { localStorage.setItem("id_token", response.data.token); - updateState({ + updateUserContext({ user: response.data.user, }); - console.log("Login successful", response); - if (taskCode && srcURL) { - history.push( - `/tasks/${taskCode}${srcURL}?assignmentId=${assignmentId}&treatmentId=${treatmentId}`, - ); - } else if (taskCode) { - history.push( - `/tasks/${taskCode}/create?assignmentId=${assignmentId}&treatmentId=${treatmentId}`, - ); - } else if (originalPath === "/") { - history.push("/account"); - } else { - history.goBack(); - } + setLocalSetted(true); }) .catch((error) => { + console.error("Login error", error); Swal.fire({ icon: "error", title: "Oops...", text: "Something went wrong! try another email or password", }); + }) + .finally(() => { + setLoading(false); }); }; @@ -85,6 +85,26 @@ const LoginPage: FC = () => { } }, [email]); + useEffect(() => { + if (localSetted) { + if (localStorage.getItem("id_token")) { + if (taskCode && srcURL) { + history.push( + `/tasks/${taskCode}${srcURL}?assignmentId=${assignmentId}&treatmentId=${treatmentId}` + ); + } else if (taskCode) { + history.push( + `/tasks/${taskCode}/create?assignmentId=${assignmentId}&treatmentId=${treatmentId}` + ); + } else if (originalPath === "/") { + history.push("/account"); + } else { + history.goBack(); + } + } + } + }, [localSetted]); + return (
@@ -140,8 +160,13 @@ const LoginPage: FC = () => {
diff --git a/frontends/web/src/new_front/pages/Login/Register.tsx b/frontends/web/src/new_front/pages/Login/Register.tsx index fed357ee4..5cf0bf198 100644 --- a/frontends/web/src/new_front/pages/Login/Register.tsx +++ b/frontends/web/src/new_front/pages/Login/Register.tsx @@ -39,7 +39,6 @@ const Register = () => { title: "Success", text: "User created successfully", }); - setSubmitting(false); window.location.href = "/"; } else { Swal.fire({ @@ -48,6 +47,7 @@ const Register = () => { text: "Something went wrong! try another email or username", }); } + setSubmitting(false); }; const renderError = (message: string) => ( From 7fb43ae0ec9a45b86970b7d38a71279fdacea783 Mon Sep 17 00:00:00 2001 From: Sara H Date: Mon, 3 Nov 2025 18:03:02 -0500 Subject: [PATCH 05/30] Run prettier in ApiService --- frontends/web/src/common/ApiService.js | 70 +++++++++++++------------- 1 file changed, 35 insertions(+), 35 deletions(-) diff --git a/frontends/web/src/common/ApiService.js b/frontends/web/src/common/ApiService.js index e1addc46b..0567779fc 100644 --- a/frontends/web/src/common/ApiService.js +++ b/frontends/web/src/common/ApiService.js @@ -100,7 +100,7 @@ export default class ApiService { method: "PUT", body: JSON.stringify(data), }, - includeCredentials, + includeCredentials ); } @@ -221,7 +221,7 @@ export default class ApiService { sort, sortDirection, metricWeights, - datasetWeights, + datasetWeights ) { const pageQuery = `limit=${limit || 10}&offset=${offset || 0}`; const sortQuery = @@ -235,7 +235,7 @@ export default class ApiService { const datasetWeightsQuery = datasetWeights ? `&ordered_scoring_dataset_weights=${encodeURIComponent( - datasetWeights.join("|"), + datasetWeights.join("|") )}` : ""; @@ -298,7 +298,7 @@ export default class ApiService { { method: "GET", }, - includeCredentials, + includeCredentials ); } @@ -315,12 +315,12 @@ export default class ApiService { `${ this.domain }/contexts/${tid}/${rid}/${method}?tags=${encodeURIComponent( - tags.join("|"), + tags.join("|") )}`, { method: "GET", }, - includeCredentials, + includeCredentials ); } @@ -329,7 +329,7 @@ export default class ApiService { rid, tags = [], context_tags = [], - annotator_id = null, + annotator_id = null ) { const includeCredentials = this.mode !== "mturk"; @@ -340,12 +340,12 @@ export default class ApiService { : ""; return this.doFetch( `${this.domain}/examples/${tid}/${rid}?tags=${encodeURIComponent( - tags.join("|"), + tags.join("|") )}${context_tags_query}${annotator_query}`, { method: "GET", }, - includeCredentials, + includeCredentials ); } @@ -356,17 +356,17 @@ export default class ApiService { maxNumFlags, minNumDisagreements, maxNumDisagreements, - tags = [], + tags = [] ) { return this.fetch( `${ this.domain }/examples/${tid}/${rid}/filtered/${minNumFlags}/${maxNumFlags}/${minNumDisagreements}/${maxNumDisagreements}?tags=${encodeURIComponent( - tags.join("|"), + tags.join("|") )}`, { method: "GET", - }, + } ); } @@ -386,7 +386,7 @@ export default class ApiService { `${this.domain}/notifications?limit=${limit || 10}&offset=${offset || 0}`, { method: "GET", - }, + } ); } @@ -397,7 +397,7 @@ export default class ApiService { }`, { method: "GET", - }, + } ); } @@ -408,7 +408,7 @@ export default class ApiService { }`, { method: "GET", - }, + } ); } @@ -417,7 +417,7 @@ export default class ApiService { `${this.domain}/users/${userId}/forks?limit=${limit}&offset=${offset}`, { method: "GET", - }, + } ); } @@ -426,7 +426,7 @@ export default class ApiService { `${this.domain}/users/${userId}/snapshots?limit=${limit}&offset=${offset}`, { method: "GET", - }, + } ); } @@ -498,7 +498,7 @@ export default class ApiService { { method: "GET", }, - includeCredentials, + includeCredentials ); } @@ -523,7 +523,7 @@ export default class ApiService { method: "PUT", body: JSON.stringify(obj), }, - includeCredentials, + includeCredentials ); } @@ -543,7 +543,7 @@ export default class ApiService { method: "PUT", body: JSON.stringify(obj), }, - includeCredentials, + includeCredentials ); } @@ -565,7 +565,7 @@ export default class ApiService { method: "POST", body: JSON.stringify(data), }, - false, + false ); } @@ -586,7 +586,7 @@ export default class ApiService { `${this.domain}/task_proposals/all/${page}/${pageLimit}`, { method: "GET", - }, + } ); } @@ -595,7 +595,7 @@ export default class ApiService { `${this.domain}/task_proposals/user/${page}/${pageLimit}`, { method: "GET", - }, + } ); } @@ -661,7 +661,7 @@ export default class ApiService { { method: "PUT", body: JSON.stringify(data), - }, + } ); } @@ -746,7 +746,7 @@ export default class ApiService { headers: { Authorization: token ? "Bearer " + token : "None", }, - }, + } ); } @@ -764,7 +764,7 @@ export default class ApiService { headers: { Authorization: token ? "Bearer " + token : "None", }, - }, + } ); } @@ -779,7 +779,7 @@ export default class ApiService { metadata, modelWrong, tag = null, - modelEndpointName = null, + modelEndpointName = null ) { const includeCredentials = this.mode !== "mturk"; return this.doFetch( @@ -800,7 +800,7 @@ export default class ApiService { model_endpoint_name: modelEndpointName, }), }, - includeCredentials, + includeCredentials ); } @@ -820,7 +820,7 @@ export default class ApiService { `${this.domain}/tasks/${tid}/leaderboard_configuration/${name}`, { method: "GET", - }, + } ); } @@ -833,7 +833,7 @@ export default class ApiService { orderedDatasetWeights, totalCount, description, - name, + name ) { return this.fetch(`${this.domain}/tasks/${tid}/leaderboard_snapshot`, { method: "PUT", @@ -855,7 +855,7 @@ export default class ApiService { `${this.domain}/tasks/${tid}/leaderboard_snapshot/${name}`, { method: "GET", - }, + } ); } @@ -864,7 +864,7 @@ export default class ApiService { `${this.domain}/tasks/${tid}/disambiguate_forks_and_snapshots/${name}`, { method: "GET", - }, + } ); } @@ -888,7 +888,7 @@ export default class ApiService { localStorage.removeItem("id_token"); //window.location.href = '/login'; return false; - }, + } ); } } @@ -966,7 +966,7 @@ export default class ApiService { return this.doFetch(`${this.alternateDomain}/auth/refresh`, {}, true).then( (result) => { this.setToken(result.token); - }, + } ); } @@ -1010,7 +1010,7 @@ export default class ApiService { localStorage.removeItem("id_token"); //window.location.href = '/login'; throw error; - }, + } ); } return this.doFetch(url, options, {}, true); From 675316fd942f12ba20bf27558d484d7d7ec728da Mon Sep 17 00:00:00 2001 From: Sara H Date: Mon, 3 Nov 2025 18:25:55 -0500 Subject: [PATCH 06/30] put back the name AUTH_JWT_SECRET_KEY for the environment variable --- backend/app/domain/auth/authentication.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/app/domain/auth/authentication.py b/backend/app/domain/auth/authentication.py index 9298adb4a..68309cfae 100644 --- a/backend/app/domain/auth/authentication.py +++ b/backend/app/domain/auth/authentication.py @@ -29,7 +29,7 @@ class LoginService: def __init__(self) -> None: - self.AUTH_JWT_SECRET_KEY = os.getenv("JWT_SECRET") + self.AUTH_JWT_SECRET_KEY = os.getenv("AUTH_JWT_SECRET_KEY") self.ACCESS_TOKEN_EXPIRE_MINUTES = int( os.getenv("AUTH_ACCESS_TOKEN_EXPIRE_MINUTES") ) From dceadfeb5b6edfe98992cd30f3da31aef5b48581 Mon Sep 17 00:00:00 2001 From: Sara H Date: Mon, 3 Nov 2025 18:27:39 -0500 Subject: [PATCH 07/30] put back the name AUTH_JWT_SECRET_KEY for the environment variable --- backend/app/api/middleware/authentication.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/app/api/middleware/authentication.py b/backend/app/api/middleware/authentication.py index 1b6eb8699..c467ded51 100644 --- a/backend/app/api/middleware/authentication.py +++ b/backend/app/api/middleware/authentication.py @@ -20,7 +20,7 @@ async def verify_token(token: str): try: decoded_token = jwt.decode( token, - os.getenv("JWT_SECRET"), + os.getenv("AUTH_JWT_SECRET_KEY"), algorithms=[os.getenv("AUTH_HASH_ALGORITHM")], options={"verify_exp": True}, ) From bb0bf0f79a1cbfe51811d29f7f746b15444ff4b1 Mon Sep 17 00:00:00 2001 From: Sara H Date: Fri, 7 Nov 2025 03:00:47 -0500 Subject: [PATCH 08/30] move is admin_or_owner to backend and fix refresh in ProfilePage --- .../common/Annotation/ValidateInterface.js | 2 +- frontends/web/src/common/ApiService.js | 79 ++++++++++--------- frontends/web/src/containers/App.js | 1 + frontends/web/src/containers/ModelPage.js | 2 +- frontends/web/src/containers/TaskOwnerPage.js | 2 +- frontends/web/src/containers/TaskPage.js | 2 +- .../pages/ProfilePage/ProfilePage.tsx | 53 ++++++++----- .../web/src/new_front/pages/Task/TaskPage.tsx | 40 ++++------ 8 files changed, 99 insertions(+), 82 deletions(-) diff --git a/frontends/web/src/common/Annotation/ValidateInterface.js b/frontends/web/src/common/Annotation/ValidateInterface.js index 89de52ca0..ce9fd4993 100644 --- a/frontends/web/src/common/Annotation/ValidateInterface.js +++ b/frontends/web/src/common/Annotation/ValidateInterface.js @@ -106,7 +106,7 @@ class ValidateInterface extends React.Component { this.context.api .getAdminOrOwner(this.state.task.id) .then((result) => { - this.setState({ admin_or_owner: result.admin_or_owner }); + this.setState({ admin_or_owner: result }); }); // eslint-disable-next-line react/no-direct-mutation-state this.state.task.selected_round = this.state.task.cur_round; diff --git a/frontends/web/src/common/ApiService.js b/frontends/web/src/common/ApiService.js index 0567779fc..5b2c80977 100644 --- a/frontends/web/src/common/ApiService.js +++ b/frontends/web/src/common/ApiService.js @@ -35,6 +35,7 @@ export default class ApiService { this.getCredentials = this.getCredentials.bind(this); this.setMturkMode = this.setMturkMode.bind(this); this.updating_already = false; + this.refreshPromise = null; this.mode = "normal"; this.exportDatasetLog = this.exportDatasetLog.bind(this); this.exportPrediction = this.exportPrediction.bind(this); @@ -494,7 +495,7 @@ export default class ApiService { getAdminOrOwner(tid) { const includeCredentials = this.mode !== "mturk"; return this.doFetch( - `${this.domain}/tasks/admin_or_owner/${tid}`, + `${this.alternateDomain}/auth/is_admin_or_owner/${tid}`, { method: "GET", }, @@ -916,7 +917,8 @@ export default class ApiService { // handle access not allowed to localStorage if disabled in browser. // https://stackoverflow.com/questions/16427636/check-if-localstorage-is-available try { - return localStorage.getItem("id_token"); + const token = localStorage.getItem("id_token"); + return token; } catch (e) { return null; } @@ -939,26 +941,27 @@ export default class ApiService { return this.getToken() ? decode(this.getToken()) : {}; } - refreshTokenWrapper(callback, error) { - if (this.updating_already) { - // TODO: Make this actually wait for an event? - return delay(1048576).then(() => { - if (this.updating_already) { - return this.refreshTokenWrapper(callback, error); - } + async refreshTokenWrapper(callback, errorCallback) { + if (this.refreshPromise) { + try { + await this.refreshPromise; return callback(); + } catch (error) { + return errorCallback(); + } + } + + // Start new refresh + this.refreshPromise = this.refreshToken() + .finally(() => { + this.refreshPromise = null; // Clear when done (success or failure) }); - } else { - this.updating_already = true; - return this.refreshToken() - .then((result) => { - this.updating_already = false; - return callback(); - }) - .catch(() => { - this.updating_already = false; - return error(); - }); + + try { + await this.refreshPromise; + return callback(); + } catch (error) { + return errorCallback(); } } @@ -992,28 +995,32 @@ export default class ApiService { return fetch(url, options).then(this.errorHandler); } - fetch(url, options) { + async fetch(url, options) { const token = this.mode !== "mturk" ? this.getToken() : null; if ( !!token && this.isTokenExpired(token) && - url !== `${this.domain}/authenticate` + url !== `${this.alternateDomain}/login` ) { - return this.refreshTokenWrapper( - (res) => { - //console.log("Our token was refreshed (fetch callback)"); - return this.doFetch(url, options, {}, true); - }, - (res) => { - console.log("Could not refresh token (fetch)"); - var error = new Error("Could not refresh token"); - localStorage.removeItem("id_token"); - //window.location.href = '/login'; - throw error; - } - ); + try { + await this.refreshTokenWrapper( + () => { + //console.log("Our token was refreshed (fetch callback)"); + return Promise.resolve(); + }, + () => { + localStorage.removeItem("id_token"); + //window.location.href = '/login'; + throw new Error("Could not refresh token"); + } + ); + + return this.doFetch(url, options, true); + } catch (error) { + throw error; + } } - return this.doFetch(url, options, {}, true); + return this.doFetch(url, options, true); } errorHandler(response) { diff --git a/frontends/web/src/containers/App.js b/frontends/web/src/containers/App.js index e01bef207..a0ccaefbe 100644 --- a/frontends/web/src/containers/App.js +++ b/frontends/web/src/containers/App.js @@ -180,6 +180,7 @@ class App extends React.Component { }, credentials: "include", mode: "cors", + cache: "no-cache", }} > diff --git a/frontends/web/src/containers/ModelPage.js b/frontends/web/src/containers/ModelPage.js index abb8f1de8..75ebbb92b 100644 --- a/frontends/web/src/containers/ModelPage.js +++ b/frontends/web/src/containers/ModelPage.js @@ -407,7 +407,7 @@ class ModelPage extends React.Component { this.context.api.getAdminOrOwner(this.state.model.tid).then( (adminOrOwnerResult) => { this.setState({ - isAdminOrTaskOwner: adminOrOwnerResult.admin_or_owner, + isAdminOrTaskOwner: adminOrOwnerResult, }); }, (error) => { diff --git a/frontends/web/src/containers/TaskOwnerPage.js b/frontends/web/src/containers/TaskOwnerPage.js index 98cf8e77f..484d0c2e0 100644 --- a/frontends/web/src/containers/TaskOwnerPage.js +++ b/frontends/web/src/containers/TaskOwnerPage.js @@ -97,7 +97,7 @@ class TaskOwnerPage extends React.Component { (adminOrOwnerResult) => { this.setState( { - admin_or_owner: adminOrOwnerResult.admin_or_owner, + admin_or_owner: adminOrOwnerResult, task: result, }, callback diff --git a/frontends/web/src/containers/TaskPage.js b/frontends/web/src/containers/TaskPage.js index 29886f1de..8862d6a8d 100644 --- a/frontends/web/src/containers/TaskPage.js +++ b/frontends/web/src/containers/TaskPage.js @@ -77,7 +77,7 @@ class TaskPage extends React.Component { this.context.api.getAdminOrOwner(result.id).then( (adminOrOwnerResult) => { this.setState({ - admin_or_owner: adminOrOwnerResult.admin_or_owner, + admin_or_owner: adminOrOwnerResult, }); }, (error) => { diff --git a/frontends/web/src/new_front/pages/ProfilePage/ProfilePage.tsx b/frontends/web/src/new_front/pages/ProfilePage/ProfilePage.tsx index 16539ee13..b6bd74507 100644 --- a/frontends/web/src/new_front/pages/ProfilePage/ProfilePage.tsx +++ b/frontends/web/src/new_front/pages/ProfilePage/ProfilePage.tsx @@ -25,41 +25,56 @@ type Props = { const ProfilePage: FC = () => { const [userInfo, setUserInfo] = useState({} as UserInfoProps); const [userStats, setUserStats] = useState( - {} as UserStatsProps, + {} as UserStatsProps ); const [modelsInfo, setModelsInfo] = useState( - [] as ModelsInfo[], + [] as ModelsInfo[] ); const [tasksCategories, setTasksCategories] = useState([]); const [tasksInfo, setTasksInfo] = useState( - [] as TaskInfoType[], + [] as TaskInfoType[] ); - const { user } = useContext(UserContext); - const { get, response, loading } = useFetch(); + const [loading, setLoading] = useState(true); + const { user, api } = useContext(UserContext); const history = useHistory(); - const userId = user.id; + const getUserInfo = async () => { if (!userId) { return; } - const [userInfo, tasksInfo, modelsInfo, userStats] = await Promise.all([ - get(`/user/get_user_with_badges/${userId}`), - get(`/task/get_active_tasks_by_user_id/${userId}`), - get(`/model/get_models_by_user_id/${userId}`), - get(`/user/get_stats_by_user_id/${userId}`), - ]); - if (response.ok) { + try { + const backendUrl = process.env.REACT_APP_API_HOST_2; + + const [userInfo, tasksInfo, modelsInfo, userStats] = await Promise.all([ + api.fetch(`${backendUrl}/user/get_user_with_badges/${userId}`, { + method: "GET", + }), + api.fetch(`${backendUrl}/task/get_active_tasks_by_user_id/${userId}`, { + method: "GET", + }), + api.fetch(`${backendUrl}/model/get_models_by_user_id/${userId}`, { + method: "GET", + }), + api.fetch(`${backendUrl}/user/get_stats_by_user_id/${userId}`, { + method: "GET", + }), + ]); setUserInfo(userInfo); setTasksInfo(tasksInfo); setModelsInfo(modelsInfo); setUserStats(userStats); - } else { - Swal.fire({ - icon: "error", - title: "Oops...", - text: "Something went wrong!", - }); + setLoading(false); + } catch (error) { + if (error.status === 401) { + history.push("/login"); + } else { + Swal.fire({ + icon: "error", + title: "Oops...", + text: "Something went wrong!", + }); + } } }; diff --git a/frontends/web/src/new_front/pages/Task/TaskPage.tsx b/frontends/web/src/new_front/pages/Task/TaskPage.tsx index 871532c04..1997ded23 100644 --- a/frontends/web/src/new_front/pages/Task/TaskPage.tsx +++ b/frontends/web/src/new_front/pages/Task/TaskPage.tsx @@ -60,28 +60,22 @@ const TaskPage = () => { history.push("/PageNotFound"); } }, - [taskCode], + [taskCode] ); - const checkAdminOrOwner = useCallback( - async (user_id: number) => { - if (user?.id && task?.id) { - const adminOrOwner = await post("/auth/is_admin_or_owner", { - task_id: task?.id, - user_id: user_id, - }); - if (response.ok) { - setAdminOrOwner(adminOrOwner); - } - } else { - setAdminOrOwner(false); + const checkAdminOrOwner = useCallback(async () => { + if (user?.id && task?.id) { + const adminOrOwner = await get(`/auth/is_admin_or_owner/${task?.id}`); + if (response.ok) { + setAdminOrOwner(adminOrOwner); } - }, - [user?.id, task], - ); + } else { + setAdminOrOwner(false); + } + }, [user?.id, task]); useEffect(() => { - user?.id && task && checkAdminOrOwner(user?.id); + user?.id && task && checkAdminOrOwner(); }, [user, task]); useEffect(() => { @@ -181,10 +175,10 @@ const TaskPage = () => { { showLeaderboard={Boolean(task.show_leaderboard)} showTrends={Boolean(task.show_trends)} showUserLeaderboard={Boolean( - task.show_user_leaderboard, + task.show_user_leaderboard )} showUserLeaderboardCSV={Boolean( - task.show_user_leaderboard_csv, + task.show_user_leaderboard_csv )} /> ) : ( @@ -239,10 +233,10 @@ const TaskPage = () => { showLeaderboard={Boolean(task.show_leaderboard)} showTrends={Boolean(task.show_trends)} showUserLeaderboard={Boolean( - task.show_user_leaderboard, + task.show_user_leaderboard )} showUserLeaderboardCSV={Boolean( - task.show_user_leaderboard_csv, + task.show_user_leaderboard_csv )} /> )} From 54be2f2ba48d3ba80698483feb62d48df87a57b7 Mon Sep 17 00:00:00 2001 From: Sara H Date: Fri, 7 Nov 2025 10:27:41 -0500 Subject: [PATCH 09/30] is_admin_or_owner touches --- backend/app/api/endpoints/auth.py | 13 ++++--------- backend/app/domain/auth/authentication.py | 14 +++++++++----- backend/app/domain/schemas/auth/auth.py | 5 ----- backend/app/main.py | 4 +++- .../pages/ProfilePage/ExamplesCreated.tsx | 3 +++ .../web/src/new_front/pages/Task/LeaderBoard.jsx | 2 +- 6 files changed, 20 insertions(+), 21 deletions(-) diff --git a/backend/app/api/endpoints/auth.py b/backend/app/api/endpoints/auth.py index f0e3f7cfa..38dc59c43 100644 --- a/backend/app/api/endpoints/auth.py +++ b/backend/app/api/endpoints/auth.py @@ -6,12 +6,7 @@ from app.api.middleware.authentication import validate_access_token from app.domain.auth.authentication import LoginService -from app.domain.schemas.auth.auth import ( - CreateUserRequest, - IsAdminOrOwnerRequest, - LoginRequest, - LoginResponse, -) +from app.domain.schemas.auth.auth import CreateUserRequest, LoginRequest, LoginResponse router = APIRouter() @@ -22,11 +17,11 @@ async def create_user(model: CreateUserRequest): return LoginService().create_user(model.email, model.password, model.username) -@router.post("/is_admin_or_owner", response_model=bool) +@router.get("/is_admin_or_owner/{task_id}", response_model=bool) async def is_admin_or_owner( - model: IsAdminOrOwnerRequest, token_payload=Depends(validate_access_token) + task_id: int, request: Request, token_payload=Depends(validate_access_token) ): - return LoginService().is_admin_or_owner(model.user_id, model.task_id) + return LoginService().is_admin_or_owner(task_id, request) @router.get("/refresh") diff --git a/backend/app/domain/auth/authentication.py b/backend/app/domain/auth/authentication.py index 68309cfae..19782dd3b 100644 --- a/backend/app/domain/auth/authentication.py +++ b/backend/app/domain/auth/authentication.py @@ -7,7 +7,7 @@ from datetime import datetime, timedelta, timezone from typing import Any, Union -from fastapi import HTTPException, status +from fastapi import HTTPException, Request, status from jose import jwt from werkzeug.security import check_password_hash, generate_password_hash @@ -90,10 +90,11 @@ def set_refresh_token(self, response, user_id: int) -> str: httponly=True, path="/", expires=cookie_expires, - # For localhost testing set secure to False - secure=False, - # For Localhost testing set samesite to None, else lax + # For localhost testing set secure to False in Prod to True + secure=True, samesite="lax", + # For localhost testing set domain to localhost + # domain="localhost" ) return refresh_token @@ -159,7 +160,10 @@ def logout(self, request, response) -> dict: else: refresh_token_expired() - def is_admin_or_owner(self, user_id: int, task_id: int): + def is_admin_or_owner(self, task_id: int, request: Request) -> bool: + user_id = request.state.user + if not user_id: + return False return self.task_user_permission_repository.is_task_owner( user_id, task_id ) or self.users_service.get_is_admin(user_id) diff --git a/backend/app/domain/schemas/auth/auth.py b/backend/app/domain/schemas/auth/auth.py index 5a0c90977..2c85b4fc1 100644 --- a/backend/app/domain/schemas/auth/auth.py +++ b/backend/app/domain/schemas/auth/auth.py @@ -29,8 +29,3 @@ class LoginResponse(BaseModel): class TokenPayload(BaseModel): access_token: str token_type: str - - -class IsAdminOrOwnerRequest(BaseModel): - user_id: int = 0 - task_id: int diff --git a/backend/app/main.py b/backend/app/main.py index acec997ee..6c4e1c90e 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -35,9 +35,11 @@ origins = [ "http://localhost:3000", + "http://127.0.0.1:3000", "https://www.dynabench.org", + "https://dynabench.org", "https://front-dev.dynabench.org", - # "postman://app", include this only when testing with postman + # "postman://app", ] app.add_middleware( diff --git a/frontends/web/src/new_front/pages/ProfilePage/ExamplesCreated.tsx b/frontends/web/src/new_front/pages/ProfilePage/ExamplesCreated.tsx index 0650f2fba..0dfe2efe2 100644 --- a/frontends/web/src/new_front/pages/ProfilePage/ExamplesCreated.tsx +++ b/frontends/web/src/new_front/pages/ProfilePage/ExamplesCreated.tsx @@ -11,6 +11,9 @@ const ExamplesCreated = () => { const { user } = useContext(UserContext); const getTasks = async () => { + if (!user.id) { + return; + } const tasks = await get( `/task/get_tasks_with_samples_created_by_user/${user.id}` ); diff --git a/frontends/web/src/new_front/pages/Task/LeaderBoard.jsx b/frontends/web/src/new_front/pages/Task/LeaderBoard.jsx index bc3ddb2f1..290afdefa 100644 --- a/frontends/web/src/new_front/pages/Task/LeaderBoard.jsx +++ b/frontends/web/src/new_front/pages/Task/LeaderBoard.jsx @@ -63,7 +63,7 @@ class Leaderboard extends React.Component { this.context.api.getAdminOrOwner(result.id).then( (adminOrOwnerResult) => { this.setState({ - admin_or_owner: adminOrOwnerResult.admin_or_owner, + admin_or_owner: adminOrOwnerResult, }); }, (error) => { From 575fbe45ba0b0e91329e31b92e2986cafad2ed1f Mon Sep 17 00:00:00 2001 From: Sara H Date: Mon, 10 Nov 2025 21:13:21 -0500 Subject: [PATCH 10/30] move endpoint users from bottle to fastAPI --- backend/app/api/endpoints/base/user.py | 9 +++++++- backend/app/domain/schemas/base/user.py | 13 +++++++++++ backend/app/domain/services/base/user.py | 22 +++++++++++++++++++ .../repositories/taskuserpermission.py | 8 +++++++ frontends/web/src/common/ApiService.js | 2 +- 5 files changed, 52 insertions(+), 2 deletions(-) diff --git a/backend/app/api/endpoints/base/user.py b/backend/app/api/endpoints/base/user.py index c9b624455..656f4ed11 100644 --- a/backend/app/api/endpoints/base/user.py +++ b/backend/app/api/endpoints/base/user.py @@ -4,7 +4,7 @@ from fastapi import APIRouter, Request -from app.domain.schemas.base.user import UserInfoBadges +from app.domain.schemas.base.user import UserInfoBadges, UserInfoBasic from app.domain.services.base.user import UserService @@ -23,3 +23,10 @@ async def get_stats_by_user_id(user_id: int, request: Request): if user_id != request.state.user: raise PermissionError("Unauthorized access to user data.") return UserService().get_stats_by_user_id(user_id) + + +@router.get("/{user_id}", response_model=UserInfoBasic) +async def get_user_by_id(user_id: int, request: Request): + if user_id != request.state.user: + raise PermissionError("Unauthorized access to user data.") + return UserService().get_user_basics_by_id(user_id) diff --git a/backend/app/domain/schemas/base/user.py b/backend/app/domain/schemas/base/user.py index 6441746b6..2e3aab8d9 100644 --- a/backend/app/domain/schemas/base/user.py +++ b/backend/app/domain/schemas/base/user.py @@ -27,3 +27,16 @@ class UserInfoBadges(BaseModel): examples_verified_incorrect_fooled: Optional[int] = None examples_fooled: Optional[int] = None badges: List[UserBadges] = [] + + +class taskPermission(BaseModel): + tid: int + type: str + + +class UserInfoBasic(BaseModel): + id: int + email: str + username: str + task_permissions: Optional[List[taskPermission]] = [] + admin: bool diff --git a/backend/app/domain/services/base/user.py b/backend/app/domain/services/base/user.py index 30b81912a..4166d5d96 100644 --- a/backend/app/domain/services/base/user.py +++ b/backend/app/domain/services/base/user.py @@ -4,6 +4,9 @@ from app.infrastructure.repositories.example import ExampleRepository from app.infrastructure.repositories.model import ModelRepository +from app.infrastructure.repositories.taskuserpermission import ( + TaskUserPermissionRepository, +) from app.infrastructure.repositories.user import UserRepository from app.infrastructure.repositories.validation import ValidationRepository @@ -14,6 +17,8 @@ def __init__(self): self.example_repository = ExampleRepository() self.validation_repository = ValidationRepository() self.model_repository = ModelRepository() + self.user_permissions_repository = UserRepository() + self.task_user_permissions_repository = TaskUserPermissionRepository() def increment_examples_fooled(self, user_id: int): self.user_repository.increment_examples_fooled(user_id) @@ -69,3 +74,20 @@ def get_stats_by_user_id(self, user_id: int): def download_users_info(self): return self.user_repository.download_users_info() + + def get_user_basics_by_id(self, user_id: int): + admin = self.user_repository.get_is_admin(user_id) + user_email = self.user_repository.get_user_email(user_id)[0] + username = self.user_repository.get_user_name_by_id(user_id)[0] + task_permissions = ( + self.task_user_permissions_repository.get_task_permissions_by_user_id( + user_id + ) + ) + return { + "admin": admin, + "email": user_email, + "username": username, + "id": user_id, + "task_permissions": task_permissions, + } diff --git a/backend/app/infrastructure/repositories/taskuserpermission.py b/backend/app/infrastructure/repositories/taskuserpermission.py index 6224f2608..abb56ddf6 100644 --- a/backend/app/infrastructure/repositories/taskuserpermission.py +++ b/backend/app/infrastructure/repositories/taskuserpermission.py @@ -22,3 +22,11 @@ def is_task_owner(self, user_id: int, task_id: int) -> bool: .filter(self.model.type == "owner") .first() ) is not None + + def get_task_permissions_by_user_id(self, user_id: int): + instances = ( + self.session.query(self.model).filter(self.model.uid == user_id).all() + ) + return [ + self.instance_converter.instance_to_dict(instance) for instance in instances + ] diff --git a/frontends/web/src/common/ApiService.js b/frontends/web/src/common/ApiService.js index 5b2c80977..be637f920 100644 --- a/frontends/web/src/common/ApiService.js +++ b/frontends/web/src/common/ApiService.js @@ -283,7 +283,7 @@ export default class ApiService { } getUser(id, badges = false) { - var url = `${this.domain}/users/${id}`; + var url = `${this.alternateDomain}/user/${id}`; if (badges) { url += "/badges"; } From bcc0f13ddd0c5f0c9ff665e45ed659b54c94f7f1 Mon Sep 17 00:00:00 2001 From: Sara H Date: Tue, 11 Nov 2025 11:37:57 -0500 Subject: [PATCH 11/30] run prettier --- frontends/web/src/common/ApiService.js | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/frontends/web/src/common/ApiService.js b/frontends/web/src/common/ApiService.js index be637f920..0e373bd26 100644 --- a/frontends/web/src/common/ApiService.js +++ b/frontends/web/src/common/ApiService.js @@ -952,10 +952,9 @@ export default class ApiService { } // Start new refresh - this.refreshPromise = this.refreshToken() - .finally(() => { - this.refreshPromise = null; // Clear when done (success or failure) - }); + this.refreshPromise = this.refreshToken().finally(() => { + this.refreshPromise = null; // Clear when done (success or failure) + }); try { await this.refreshPromise; From 2bdf559f298e619ef7bdada98a614896c2c0d2b7 Mon Sep 17 00:00:00 2001 From: Sara H Date: Wed, 12 Nov 2025 12:03:34 -0500 Subject: [PATCH 12/30] Fix buttons when not needed in help-med and ps-on-ai --- .../Contexts/ChatRandomWithInstructions.tsx | 23 ++++--- .../Contexts/ChatWithInstructions.tsx | 38 +++++++----- .../AnnotationInterfaces/Contexts/Chatbot.tsx | 50 +++++++++------- .../Contexts/EvaluateTextsGenerative.tsx | 60 ++++++++++++------- .../components/Inputs/EvaluateText.tsx | 4 +- .../createSamples/createSamples/utils.ts | 1 + 6 files changed, 109 insertions(+), 67 deletions(-) diff --git a/frontends/web/src/new_front/components/CreateSamples/CreateSamples/AnnotationInterfaces/Contexts/ChatRandomWithInstructions.tsx b/frontends/web/src/new_front/components/CreateSamples/CreateSamples/AnnotationInterfaces/Contexts/ChatRandomWithInstructions.tsx index 19f9a809d..2ed9d7e8e 100644 --- a/frontends/web/src/new_front/components/CreateSamples/CreateSamples/AnnotationInterfaces/Contexts/ChatRandomWithInstructions.tsx +++ b/frontends/web/src/new_front/components/CreateSamples/CreateSamples/AnnotationInterfaces/Contexts/ChatRandomWithInstructions.tsx @@ -53,10 +53,10 @@ const ChatRandomWithInstructions: FC< >([]); const [finishConversation, setFinishConversation] = useState(false); const [readInstructions, setReadInstructions] = useState( - artifactsInput?.jump_instructions ? true : false, + artifactsInput?.jump_instructions ? true : false ); const { updateModelInputs, modelInputs, cleanModelInputs } = useContext( - CreateInterfaceContext, + CreateInterfaceContext ); const { get, post, response, loading } = useFetch(); const { user } = useContext(UserContext); @@ -71,7 +71,7 @@ const ChatRandomWithInstructions: FC< user_id: user.id, round_id: realRoundId, url: artifactsInput?.redirect_url || null, - }, + } ); if (response.ok) { if (redirectUrl) { @@ -116,7 +116,7 @@ const ChatRandomWithInstructions: FC< { user_id: user.id, task_id: taskId, - }, + } ); if (response.ok) { setCallLoading(false); @@ -145,7 +145,7 @@ const ChatRandomWithInstructions: FC< const handlePreliminaryQuestionsSubmit = async () => { const requiredFields = preliminaryQuestions.map( - (question) => question?.field_name_for_the_model, + (question) => question?.field_name_for_the_model ); const allAnswered = requiredFields.every( @@ -153,7 +153,7 @@ const ChatRandomWithInstructions: FC< field in modelInputs && modelInputs[field] !== null && modelInputs[field] !== "" && - modelInputs[field] !== undefined, + modelInputs[field] !== undefined ); if (!allAnswered) { Swal.fire({ @@ -190,10 +190,10 @@ const ChatRandomWithInstructions: FC< try { const [contextResponse, modelResponse] = await Promise.all([ get( - `/context/get_distinct_context?user_id=${user.id}&round_id=${realRoundId}`, + `/context/get_distinct_context?user_id=${user.id}&round_id=${realRoundId}` ), get( - `/task/get_random_provider_and_model_info?task_id=${taskId}&user_id=${user.id}`, + `/task/get_random_provider_and_model_info?task_id=${taskId}&user_id=${user.id}` ), ]); if (response.ok) { @@ -308,7 +308,7 @@ const ChatRandomWithInstructions: FC<
{parse( - generative_context.artifacts.first_explainatory_block, + generative_context.artifacts.first_explainatory_block )} @@ -355,6 +355,11 @@ const ChatRandomWithInstructions: FC< updateModelInputs={updateModelInputs} setIsGenerativeContext={setIsGenerativeContext} allowPaste={artifactsInput?.allow_paste} + rateAtTheEnd={ + "rate_at_the_end" in artifactsInput + ? artifactsInput.rate_at_the_end + : true + } />
diff --git a/frontends/web/src/new_front/components/CreateSamples/CreateSamples/AnnotationInterfaces/Contexts/ChatWithInstructions.tsx b/frontends/web/src/new_front/components/CreateSamples/CreateSamples/AnnotationInterfaces/Contexts/ChatWithInstructions.tsx index 0ab7ec8ad..2e6b1fd84 100644 --- a/frontends/web/src/new_front/components/CreateSamples/CreateSamples/AnnotationInterfaces/Contexts/ChatWithInstructions.tsx +++ b/frontends/web/src/new_front/components/CreateSamples/CreateSamples/AnnotationInterfaces/Contexts/ChatWithInstructions.tsx @@ -244,11 +244,13 @@ const ChatWithInstructions: FC< about how best to respond:

- 1) What healthcare service do you need? (e.g. A&E or routine GP follow-up) + 1) What healthcare service do you need? (e.g. A&E or + routine GP follow-up)

- 2) Why did you make the choice you did? Please name all of the - specific medical conditions you consider relevant to your decision. (e.g. suspected broken bone) + 2) Why did you make the choice you did? Please name all + of the specific medical conditions you consider relevant + to your decision. (e.g. suspected broken bone)

The scenario (available below and on the next page) @@ -264,14 +266,15 @@ const ChatWithInstructions: FC< {treatmentValue !== "control" ? ( <>

- To assist in completing the scenarios, please use the - language model provided. We are interested in - understanding how you use the language model provided - and how well it works for you. Therefore, it is - essential that you{" "}only use your own - words, and do not copy and paste from the - scenario text, or from any other source. Please do not - use additional external sources. + To assist in completing the scenarios, please use + the language model provided. We are interested in + understanding how you use the language model + provided and how well it works for you. Therefore, + it is essential that you{" "} + only use your own words, and do not + copy and paste from the scenario text, or from any + other source. Please do not use additional external + sources.

) : ( @@ -325,7 +328,9 @@ const ChatWithInstructions: FC<
)}
-

Scenario

+

+ Scenario +

@@ -360,6 +365,11 @@ const ChatWithInstructions: FC< setFinishConversation={setFinishConversation} updateModelInputs={updateModelInputs} setIsGenerativeContext={setIsGenerativeContext} + rateAtTheEnd={ + artifactsInput.rate_at_the_end + ? artifactsInput.rate_at_the_end + : true + } /> ) : ( <> @@ -378,8 +388,8 @@ const ChatWithInstructions: FC<

Keep track of the methods you are using in the - textbox below. The questions will appear at the - bottom of the page after you have finished putting + textbox below. The questions will appear at the + bottom of the page after you have finished putting in your approach.

diff --git a/frontends/web/src/new_front/components/CreateSamples/CreateSamples/AnnotationInterfaces/Contexts/Chatbot.tsx b/frontends/web/src/new_front/components/CreateSamples/CreateSamples/AnnotationInterfaces/Contexts/Chatbot.tsx index fbde72104..904934aa3 100644 --- a/frontends/web/src/new_front/components/CreateSamples/CreateSamples/AnnotationInterfaces/Contexts/Chatbot.tsx +++ b/frontends/web/src/new_front/components/CreateSamples/CreateSamples/AnnotationInterfaces/Contexts/Chatbot.tsx @@ -22,6 +22,7 @@ const Chatbot: FC = ({ numInteractionsChatbot, finishConversation, optionsSlider, + rateAtTheEnd, setChatHistory, showOriginalInteractions, setFinishConversation, @@ -495,30 +496,35 @@ const Chatbot: FC = ({ texts={newResponses} setTexts={setNewResponses} optionsSlider={optionsSlider} + rateAtTheEnd={rateAtTheEnd} score={response.score} handleWhenButtons={handleOnClick} /> ))} - {newResponses.length > 0 && !optionsSlider && ( -
- handleOnClick("tie")} - text={"It's a tie 🤝"} - className="border-0 font-weight-bold light-gray-bg task-action-btn" - active={currentSelection === "tie"} - /> -
- )} - {newResponses.length > 0 && !optionsSlider && ( -
- handleOnClick("all_bad")} - text={"All are bad 🚫"} - className="border-0 font-weight-bold light-gray-bg task-action-btn" - active={currentSelection === "all_bad"} - /> -
- )} + {newResponses.length > 0 && + !optionsSlider && + !rateAtTheEnd && ( +
+ handleOnClick("tie")} + text={"It's a tie 🤝 1"} + className="border-0 font-weight-bold light-gray-bg task-action-btn" + active={currentSelection === "tie"} + /> +
+ )} + {newResponses.length > 0 && + !optionsSlider && + !rateAtTheEnd && ( +
+ handleOnClick("all_bad")} + text={"All are bad 🚫"} + className="border-0 font-weight-bold light-gray-bg task-action-btn" + active={currentSelection === "all_bad"} + /> +
+ )} {showReasonModal && ( = ({ onClick={saveHistoryValidation} text="Save" className="px-4 py-1 font-semibold border-0 font-weight-bold light-gray-bg task-action-btn " - disabled={!currentSelection && !optionsSlider} + disabled={ + !currentSelection && !optionsSlider && !rateAtTheEnd + } /> )} diff --git a/frontends/web/src/new_front/components/CreateSamples/CreateSamples/AnnotationInterfaces/Contexts/EvaluateTextsGenerative.tsx b/frontends/web/src/new_front/components/CreateSamples/CreateSamples/AnnotationInterfaces/Contexts/EvaluateTextsGenerative.tsx index f586db786..d4d4b5cb5 100644 --- a/frontends/web/src/new_front/components/CreateSamples/CreateSamples/AnnotationInterfaces/Contexts/EvaluateTextsGenerative.tsx +++ b/frontends/web/src/new_front/components/CreateSamples/CreateSamples/AnnotationInterfaces/Contexts/EvaluateTextsGenerative.tsx @@ -551,6 +551,11 @@ const EvaluateTextsGenerative: FC< id={text.id} texts={texts} setTexts={setTexts} + rateAtTheEnd={ + "rate_at_the_end" in artifactsInput + ? artifactsInput?.rate_at_the_end + : true + } optionsSlider={artifactsInput.options_slider} disabled={finishConversation} bestAnswer={bestAnswer.text} @@ -558,28 +563,34 @@ const EvaluateTextsGenerative: FC< handleWhenButtons={handleOnClick} /> ))} - {!artifactsInput?.options_slider && ( -
- handleOnClick("tie")} - text={"It's a tie 🤝"} - className="border-0 font-weight-bold light-gray-bg task-action-btn" - active={currentSelection === "tie"} - disabled={finishConversation} - /> -
- )} - {!artifactsInput?.options_slider && ( -
- handleOnClick("all_bad")} - text={"All are bad 🚫"} - className="border-0 font-weight-bold light-gray-bg task-action-btn" - active={currentSelection === "all_bad"} - disabled={finishConversation} - /> -
- )} + {!artifactsInput?.options_slider && + !("rate_at_the_end" in artifactsInput + ? artifactsInput?.rate_at_the_end + : true) && ( +
+ handleOnClick("tie")} + text={"It's a tie 🤝"} + className="border-0 font-weight-bold light-gray-bg task-action-btn" + active={currentSelection === "tie"} + disabled={finishConversation} + /> +
+ )} + {!artifactsInput?.options_slider && + !("rate_at_the_end" in artifactsInput + ? artifactsInput?.rate_at_the_end + : true) && ( +
+ handleOnClick("all_bad")} + text={"All are bad 🚫"} + className="border-0 font-weight-bold light-gray-bg task-action-btn" + active={currentSelection === "all_bad"} + disabled={finishConversation} + /> +
+ )} )} {!finishConversation && @@ -620,6 +631,11 @@ const EvaluateTextsGenerative: FC< "choose_when_tie" in artifactsInput && artifactsInput?.choose_when_tie } + rateAtTheEnd={ + "rate_at_the_end" in artifactsInput + ? artifactsInput?.rate_at_the_end + : true + } showChosenHistory={ "show_chosen_history" in artifactsInput && artifactsInput?.show_chosen_history diff --git a/frontends/web/src/new_front/components/Inputs/EvaluateText.tsx b/frontends/web/src/new_front/components/Inputs/EvaluateText.tsx index 49a5397e6..a116b5d5d 100644 --- a/frontends/web/src/new_front/components/Inputs/EvaluateText.tsx +++ b/frontends/web/src/new_front/components/Inputs/EvaluateText.tsx @@ -12,6 +12,7 @@ type EvaluateTextProps = { bestAnswer?: string; score?: number; handleWhenButtons?: any; + rateAtTheEnd?: boolean; }; const EvaluateText: FC = ({ @@ -24,6 +25,7 @@ const EvaluateText: FC = ({ bestAnswer, score = 50, handleWhenButtons, + rateAtTheEnd, }) => { const handleUpdateScore = (event: any) => { setTexts( @@ -93,7 +95,7 @@ const EvaluateText: FC = ({ )} - {!optionsSlider && ( + {!optionsSlider && !rateAtTheEnd && (
Option # {id + 1} Date: Wed, 12 Nov 2025 12:09:10 -0500 Subject: [PATCH 13/30] undo last commit --- .../Contexts/ChatRandomWithInstructions.tsx | 23 +++---- .../Contexts/ChatWithInstructions.tsx | 38 +++++------- .../AnnotationInterfaces/Contexts/Chatbot.tsx | 50 +++++++--------- .../Contexts/EvaluateTextsGenerative.tsx | 60 +++++++------------ .../components/Inputs/EvaluateText.tsx | 4 +- .../createSamples/createSamples/utils.ts | 1 - 6 files changed, 67 insertions(+), 109 deletions(-) diff --git a/frontends/web/src/new_front/components/CreateSamples/CreateSamples/AnnotationInterfaces/Contexts/ChatRandomWithInstructions.tsx b/frontends/web/src/new_front/components/CreateSamples/CreateSamples/AnnotationInterfaces/Contexts/ChatRandomWithInstructions.tsx index 2ed9d7e8e..19f9a809d 100644 --- a/frontends/web/src/new_front/components/CreateSamples/CreateSamples/AnnotationInterfaces/Contexts/ChatRandomWithInstructions.tsx +++ b/frontends/web/src/new_front/components/CreateSamples/CreateSamples/AnnotationInterfaces/Contexts/ChatRandomWithInstructions.tsx @@ -53,10 +53,10 @@ const ChatRandomWithInstructions: FC< >([]); const [finishConversation, setFinishConversation] = useState(false); const [readInstructions, setReadInstructions] = useState( - artifactsInput?.jump_instructions ? true : false + artifactsInput?.jump_instructions ? true : false, ); const { updateModelInputs, modelInputs, cleanModelInputs } = useContext( - CreateInterfaceContext + CreateInterfaceContext, ); const { get, post, response, loading } = useFetch(); const { user } = useContext(UserContext); @@ -71,7 +71,7 @@ const ChatRandomWithInstructions: FC< user_id: user.id, round_id: realRoundId, url: artifactsInput?.redirect_url || null, - } + }, ); if (response.ok) { if (redirectUrl) { @@ -116,7 +116,7 @@ const ChatRandomWithInstructions: FC< { user_id: user.id, task_id: taskId, - } + }, ); if (response.ok) { setCallLoading(false); @@ -145,7 +145,7 @@ const ChatRandomWithInstructions: FC< const handlePreliminaryQuestionsSubmit = async () => { const requiredFields = preliminaryQuestions.map( - (question) => question?.field_name_for_the_model + (question) => question?.field_name_for_the_model, ); const allAnswered = requiredFields.every( @@ -153,7 +153,7 @@ const ChatRandomWithInstructions: FC< field in modelInputs && modelInputs[field] !== null && modelInputs[field] !== "" && - modelInputs[field] !== undefined + modelInputs[field] !== undefined, ); if (!allAnswered) { Swal.fire({ @@ -190,10 +190,10 @@ const ChatRandomWithInstructions: FC< try { const [contextResponse, modelResponse] = await Promise.all([ get( - `/context/get_distinct_context?user_id=${user.id}&round_id=${realRoundId}` + `/context/get_distinct_context?user_id=${user.id}&round_id=${realRoundId}`, ), get( - `/task/get_random_provider_and_model_info?task_id=${taskId}&user_id=${user.id}` + `/task/get_random_provider_and_model_info?task_id=${taskId}&user_id=${user.id}`, ), ]); if (response.ok) { @@ -308,7 +308,7 @@ const ChatRandomWithInstructions: FC<
{parse( - generative_context.artifacts.first_explainatory_block + generative_context.artifacts.first_explainatory_block, )}
@@ -355,11 +355,6 @@ const ChatRandomWithInstructions: FC< updateModelInputs={updateModelInputs} setIsGenerativeContext={setIsGenerativeContext} allowPaste={artifactsInput?.allow_paste} - rateAtTheEnd={ - "rate_at_the_end" in artifactsInput - ? artifactsInput.rate_at_the_end - : true - } />
diff --git a/frontends/web/src/new_front/components/CreateSamples/CreateSamples/AnnotationInterfaces/Contexts/ChatWithInstructions.tsx b/frontends/web/src/new_front/components/CreateSamples/CreateSamples/AnnotationInterfaces/Contexts/ChatWithInstructions.tsx index 2e6b1fd84..0ab7ec8ad 100644 --- a/frontends/web/src/new_front/components/CreateSamples/CreateSamples/AnnotationInterfaces/Contexts/ChatWithInstructions.tsx +++ b/frontends/web/src/new_front/components/CreateSamples/CreateSamples/AnnotationInterfaces/Contexts/ChatWithInstructions.tsx @@ -244,13 +244,11 @@ const ChatWithInstructions: FC< about how best to respond:

- 1) What healthcare service do you need? (e.g. A&E or - routine GP follow-up) + 1) What healthcare service do you need? (e.g. A&E or routine GP follow-up)

- 2) Why did you make the choice you did? Please name all - of the specific medical conditions you consider relevant - to your decision. (e.g. suspected broken bone) + 2) Why did you make the choice you did? Please name all of the + specific medical conditions you consider relevant to your decision. (e.g. suspected broken bone)

The scenario (available below and on the next page) @@ -266,15 +264,14 @@ const ChatWithInstructions: FC< {treatmentValue !== "control" ? ( <>

- To assist in completing the scenarios, please use - the language model provided. We are interested in - understanding how you use the language model - provided and how well it works for you. Therefore, - it is essential that you{" "} - only use your own words, and do not - copy and paste from the scenario text, or from any - other source. Please do not use additional external - sources. + To assist in completing the scenarios, please use the + language model provided. We are interested in + understanding how you use the language model provided + and how well it works for you. Therefore, it is + essential that you{" "}only use your own + words, and do not copy and paste from the + scenario text, or from any other source. Please do not + use additional external sources.

) : ( @@ -328,9 +325,7 @@ const ChatWithInstructions: FC<
)}
-

- Scenario -

+

Scenario

@@ -365,11 +360,6 @@ const ChatWithInstructions: FC< setFinishConversation={setFinishConversation} updateModelInputs={updateModelInputs} setIsGenerativeContext={setIsGenerativeContext} - rateAtTheEnd={ - artifactsInput.rate_at_the_end - ? artifactsInput.rate_at_the_end - : true - } /> ) : ( <> @@ -388,8 +378,8 @@ const ChatWithInstructions: FC<

Keep track of the methods you are using in the - textbox below. The questions will appear at the - bottom of the page after you have finished putting + textbox below. The questions will appear at the + bottom of the page after you have finished putting in your approach.

diff --git a/frontends/web/src/new_front/components/CreateSamples/CreateSamples/AnnotationInterfaces/Contexts/Chatbot.tsx b/frontends/web/src/new_front/components/CreateSamples/CreateSamples/AnnotationInterfaces/Contexts/Chatbot.tsx index 904934aa3..fbde72104 100644 --- a/frontends/web/src/new_front/components/CreateSamples/CreateSamples/AnnotationInterfaces/Contexts/Chatbot.tsx +++ b/frontends/web/src/new_front/components/CreateSamples/CreateSamples/AnnotationInterfaces/Contexts/Chatbot.tsx @@ -22,7 +22,6 @@ const Chatbot: FC = ({ numInteractionsChatbot, finishConversation, optionsSlider, - rateAtTheEnd, setChatHistory, showOriginalInteractions, setFinishConversation, @@ -496,35 +495,30 @@ const Chatbot: FC = ({ texts={newResponses} setTexts={setNewResponses} optionsSlider={optionsSlider} - rateAtTheEnd={rateAtTheEnd} score={response.score} handleWhenButtons={handleOnClick} /> ))} - {newResponses.length > 0 && - !optionsSlider && - !rateAtTheEnd && ( -
- handleOnClick("tie")} - text={"It's a tie 🤝 1"} - className="border-0 font-weight-bold light-gray-bg task-action-btn" - active={currentSelection === "tie"} - /> -
- )} - {newResponses.length > 0 && - !optionsSlider && - !rateAtTheEnd && ( -
- handleOnClick("all_bad")} - text={"All are bad 🚫"} - className="border-0 font-weight-bold light-gray-bg task-action-btn" - active={currentSelection === "all_bad"} - /> -
- )} + {newResponses.length > 0 && !optionsSlider && ( +
+ handleOnClick("tie")} + text={"It's a tie 🤝"} + className="border-0 font-weight-bold light-gray-bg task-action-btn" + active={currentSelection === "tie"} + /> +
+ )} + {newResponses.length > 0 && !optionsSlider && ( +
+ handleOnClick("all_bad")} + text={"All are bad 🚫"} + className="border-0 font-weight-bold light-gray-bg task-action-btn" + active={currentSelection === "all_bad"} + /> +
+ )} {showReasonModal && ( = ({ onClick={saveHistoryValidation} text="Save" className="px-4 py-1 font-semibold border-0 font-weight-bold light-gray-bg task-action-btn " - disabled={ - !currentSelection && !optionsSlider && !rateAtTheEnd - } + disabled={!currentSelection && !optionsSlider} /> )} diff --git a/frontends/web/src/new_front/components/CreateSamples/CreateSamples/AnnotationInterfaces/Contexts/EvaluateTextsGenerative.tsx b/frontends/web/src/new_front/components/CreateSamples/CreateSamples/AnnotationInterfaces/Contexts/EvaluateTextsGenerative.tsx index d4d4b5cb5..f586db786 100644 --- a/frontends/web/src/new_front/components/CreateSamples/CreateSamples/AnnotationInterfaces/Contexts/EvaluateTextsGenerative.tsx +++ b/frontends/web/src/new_front/components/CreateSamples/CreateSamples/AnnotationInterfaces/Contexts/EvaluateTextsGenerative.tsx @@ -551,11 +551,6 @@ const EvaluateTextsGenerative: FC< id={text.id} texts={texts} setTexts={setTexts} - rateAtTheEnd={ - "rate_at_the_end" in artifactsInput - ? artifactsInput?.rate_at_the_end - : true - } optionsSlider={artifactsInput.options_slider} disabled={finishConversation} bestAnswer={bestAnswer.text} @@ -563,34 +558,28 @@ const EvaluateTextsGenerative: FC< handleWhenButtons={handleOnClick} /> ))} - {!artifactsInput?.options_slider && - !("rate_at_the_end" in artifactsInput - ? artifactsInput?.rate_at_the_end - : true) && ( -
- handleOnClick("tie")} - text={"It's a tie 🤝"} - className="border-0 font-weight-bold light-gray-bg task-action-btn" - active={currentSelection === "tie"} - disabled={finishConversation} - /> -
- )} - {!artifactsInput?.options_slider && - !("rate_at_the_end" in artifactsInput - ? artifactsInput?.rate_at_the_end - : true) && ( -
- handleOnClick("all_bad")} - text={"All are bad 🚫"} - className="border-0 font-weight-bold light-gray-bg task-action-btn" - active={currentSelection === "all_bad"} - disabled={finishConversation} - /> -
- )} + {!artifactsInput?.options_slider && ( +
+ handleOnClick("tie")} + text={"It's a tie 🤝"} + className="border-0 font-weight-bold light-gray-bg task-action-btn" + active={currentSelection === "tie"} + disabled={finishConversation} + /> +
+ )} + {!artifactsInput?.options_slider && ( +
+ handleOnClick("all_bad")} + text={"All are bad 🚫"} + className="border-0 font-weight-bold light-gray-bg task-action-btn" + active={currentSelection === "all_bad"} + disabled={finishConversation} + /> +
+ )} )} {!finishConversation && @@ -631,11 +620,6 @@ const EvaluateTextsGenerative: FC< "choose_when_tie" in artifactsInput && artifactsInput?.choose_when_tie } - rateAtTheEnd={ - "rate_at_the_end" in artifactsInput - ? artifactsInput?.rate_at_the_end - : true - } showChosenHistory={ "show_chosen_history" in artifactsInput && artifactsInput?.show_chosen_history diff --git a/frontends/web/src/new_front/components/Inputs/EvaluateText.tsx b/frontends/web/src/new_front/components/Inputs/EvaluateText.tsx index a116b5d5d..49a5397e6 100644 --- a/frontends/web/src/new_front/components/Inputs/EvaluateText.tsx +++ b/frontends/web/src/new_front/components/Inputs/EvaluateText.tsx @@ -12,7 +12,6 @@ type EvaluateTextProps = { bestAnswer?: string; score?: number; handleWhenButtons?: any; - rateAtTheEnd?: boolean; }; const EvaluateText: FC = ({ @@ -25,7 +24,6 @@ const EvaluateText: FC = ({ bestAnswer, score = 50, handleWhenButtons, - rateAtTheEnd, }) => { const handleUpdateScore = (event: any) => { setTexts( @@ -95,7 +93,7 @@ const EvaluateText: FC = ({ )} - {!optionsSlider && !rateAtTheEnd && ( + {!optionsSlider && (
Option # {id + 1} Date: Sun, 16 Nov 2025 20:49:25 -0500 Subject: [PATCH 14/30] create Tasks endpoint --- backend/app/api/endpoints/base/task.py | 5 +++++ backend/app/domain/services/base/task.py | 20 +++++++++++++++++++ .../app/infrastructure/repositories/task.py | 11 ++++++++++ 3 files changed, 36 insertions(+) diff --git a/backend/app/api/endpoints/base/task.py b/backend/app/api/endpoints/base/task.py index 3e405a5a4..00ec4d691 100644 --- a/backend/app/api/endpoints/base/task.py +++ b/backend/app/api/endpoints/base/task.py @@ -150,3 +150,8 @@ def get_random_provider_and_model_info(task_id: int, user_id: int): @router.get("/get_task_consent", response_model={}) def get_task_consent(task_id: int): return TaskService().get_task_consent(task_id) + + +@router.get("/all_active") +def get_tasks(): + return TaskService().get_tasks() diff --git a/backend/app/domain/services/base/task.py b/backend/app/domain/services/base/task.py index 21ae99d10..8626b3167 100644 --- a/backend/app/domain/services/base/task.py +++ b/backend/app/domain/services/base/task.py @@ -372,3 +372,23 @@ def get_task_consent(self, task_id: int): consent_file = json.loads(consent_file["Body"].read()) return consent_file + + def get_tasks(self, exclude_hidden: bool = True): + filters = { + "active": int(exclude_hidden), + } + tasks = self.task_repository.get_tasks_current_round(filters) + converted_tasks = [] + flag = 0 + for task, round_obj in tasks: + task_dict = { + **{k: v for k, v in task.__dict__.items() if not k.startswith("_")}, + "round": { + k: v for k, v in round_obj.__dict__.items() if not k.startswith("_") + }, + } + converted_tasks.append(task_dict) + if flag < 1: + print("task_dict", task_dict) + flag += 1 + return converted_tasks diff --git a/backend/app/infrastructure/repositories/task.py b/backend/app/infrastructure/repositories/task.py index 50861294d..43a34e6a1 100644 --- a/backend/app/infrastructure/repositories/task.py +++ b/backend/app/infrastructure/repositories/task.py @@ -187,3 +187,14 @@ def get_config_file_by_task_id(self, task_id: int): .filter(self.model.id == task_id) .first() ) + + def get_tasks_current_round(self, filters: dict = {}): + query = self.session.query(self.model, Round).join( + Round, (Round.tid == self.model.id) & (Round.rid == self.model.cur_round) + ) + for column_name, value in filters.items(): + if hasattr(self.model, column_name): + column = getattr(self.model, column_name) + query = query.filter(column == value) + instances = query.all() + return self.instance_converter.instance_to_dict(instances) From 996c65584a8b832f0ba95825bf7527f205bdeff9 Mon Sep 17 00:00:00 2001 From: Sara H Date: Sun, 16 Nov 2025 21:05:41 -0500 Subject: [PATCH 15/30] implement all_active tasks endpoint --- frontends/web/src/common/ApiService.js | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/frontends/web/src/common/ApiService.js b/frontends/web/src/common/ApiService.js index 0e373bd26..10386c6cf 100644 --- a/frontends/web/src/common/ApiService.js +++ b/frontends/web/src/common/ApiService.js @@ -153,7 +153,7 @@ export default class ApiService { } getTasks() { - return this.fetch(`${this.domain}/tasks`, { + return this.fetch(`${this.alternateDomain}/task/all_active`, { method: "GET", }); } @@ -449,7 +449,7 @@ export default class ApiService { } exportData(tid, rid = null) { - var export_link = `${this.domain}/tasks/${tid}`; + var export_link = `${this.alternateDomain}/tasks/${tid}`; if (rid !== null) { export_link += `/rounds/${rid}`; } From 76dcde430e695b2c82e82944250217a64d234349 Mon Sep 17 00:00:00 2001 From: Sara H Date: Wed, 19 Nov 2025 11:40:07 -0500 Subject: [PATCH 16/30] Move get task metadata from Bottle to FastAPI --- backend/app/api/endpoints/base/task.py | 5 ++ backend/app/domain/services/base/task.py | 88 +++++++++++++++++++ .../eval_utils/metrics_dicts.py | 1 - .../app/infrastructure/repositories/task.py | 11 +++ frontends/web/src/common/ApiService.js | 2 +- 5 files changed, 105 insertions(+), 2 deletions(-) diff --git a/backend/app/api/endpoints/base/task.py b/backend/app/api/endpoints/base/task.py index 00ec4d691..e51e3ebc0 100644 --- a/backend/app/api/endpoints/base/task.py +++ b/backend/app/api/endpoints/base/task.py @@ -155,3 +155,8 @@ def get_task_consent(task_id: int): @router.get("/all_active") def get_tasks(): return TaskService().get_tasks() + + +@router.get("/round_and_metric_data/{task_code}", response_model={}) +def get_task_with_round_and_metric_data(task_code: str): + return TaskService().get_task_with_round_and_metric_data(task_code) diff --git a/backend/app/domain/services/base/task.py b/backend/app/domain/services/base/task.py index 8626b3167..c74eb5d0a 100644 --- a/backend/app/domain/services/base/task.py +++ b/backend/app/domain/services/base/task.py @@ -392,3 +392,91 @@ def get_tasks(self, exclude_hidden: bool = True): print("task_dict", task_dict) flag += 1 return converted_tasks + + def get_task_with_round_and_metric_data(self, task_code: str): + try: + task, round_obj = self.task_repository.get_task_with_round_info(task_code) + + datasets = self.dataset_repository.get_order_datasets_by_task_id(task.id) + dataset_list = [] + scoring_dataset_list = [] + for dataset in datasets: + dataset_list.append({"id": dataset.id, "name": dataset.name}) + if dataset.access_type == "scoring": + scoring_dataset_list.append( + { + "id": dataset.id, + "name": dataset.name, + "default_weight": self.get_dataset_weight(dataset.id), + } + ) + dataset_list.sort(key=lambda dataset: dataset["id"]) + scoring_dataset_list.sort(key=lambda dataset: dataset["id"]) + + task_dict = task.__dict__ + round_dict = round_obj.__dict__ + task_dict["ordered_scoring_datasets"] = scoring_dataset_list + task_dict["ordered_datasets"] = dataset_list + config = yaml.load(task_dict["config_yaml"], yaml.SafeLoader) + if "perf_metric" in config: + if isinstance(config["perf_metric"], list): + principal_metric = config["perf_metric"][0] + task_dict["perf_metric_field_name"] = principal_metric["type"] + elif isinstance(config["perf_metric"], dict): + task_dict["perf_metric_field_name"] = config["perf_metric"]["type"] + metrics_meta, ordered_field_names = self.get_task_metrics_meta( + task_dict + ) + print("ordered_field_names", ordered_field_names) + ordered_metrics = [ + dict( + { + "name": metrics_meta[field_name]["pretty_name"], + "field_name": field_name, + "default_weight": self.get_metric_weight( + field_name, + task_dict.get("perf_metric_field_name"), + config.get("aggregation_metric", {}).get( + "default_weights", {-1: 1} + ), + ), + }, + **metrics_meta[field_name], + ) + for field_name in ordered_field_names + ] + + task_dict["ordered_metrics"] = ordered_metrics + task_dict["round"] = round_dict + return task_dict + except Exception as e: + return False + + def get_task_metrics_meta(self, task): + task_config = yaml.load(task["config_yaml"], yaml.SafeLoader) + if isinstance(task_config["perf_metric"], list): + perf_metric_type = [obj["type"] for obj in task_config["perf_metric"]] + elif isinstance(task_config["perf_metric"], dict): + perf_metric_type = [task_config["perf_metric"]["type"]] + delta_metric_types = [ + obj["type"] for obj in task_config.get("delta_metrics", []) + ] + aws_metric_names = instance_property.get(task.get("instance_type"), {}).get( + "aws_metrics", [] + ) + principal_metric = perf_metric_type[0] + + # TODO: make it possible to display some modes with aws metrics and some + # models without aws metrics on the same leaderboard? + if task.get("predictions_upload", False) or "train_file_metric" in task_config: + aws_metric_names = [] + ordered_metric_field_names = ( + perf_metric_type + aws_metric_names + delta_metric_types + ) + metrics_meta = { + metric: meta_metrics_dict.get(metric, meta_metrics_dict[principal_metric])( + task + ) + for metric in ordered_metric_field_names + } + return metrics_meta, ordered_metric_field_names diff --git a/backend/app/domain/services/builder_and_evaluation/eval_utils/metrics_dicts.py b/backend/app/domain/services/builder_and_evaluation/eval_utils/metrics_dicts.py index 95298aa25..666402d1f 100644 --- a/backend/app/domain/services/builder_and_evaluation/eval_utils/metrics_dicts.py +++ b/backend/app/domain/services/builder_and_evaluation/eval_utils/metrics_dicts.py @@ -86,7 +86,6 @@ meta_metrics_dict = { "accuracy": get_accuracy_meta, "new_accuracy": get_new_accuracy_meta, - "perf": get_accuracy_meta, "matthews_correlation": get_matthews_correlation_meta, "f1": get_f1_meta, "macro_f1": get_macro_f1_meta, diff --git a/backend/app/infrastructure/repositories/task.py b/backend/app/infrastructure/repositories/task.py index 43a34e6a1..a61609348 100644 --- a/backend/app/infrastructure/repositories/task.py +++ b/backend/app/infrastructure/repositories/task.py @@ -198,3 +198,14 @@ def get_tasks_current_round(self, filters: dict = {}): query = query.filter(column == value) instances = query.all() return self.instance_converter.instance_to_dict(instances) + + def get_task_with_round_info(self, task_code: str): + return ( + self.session.query(self.model, Round) + .filter(self.model.task_code == task_code) + .join( + Round, + (Round.tid == self.model.id) & (Round.rid == self.model.cur_round), + ) + .first() + ) diff --git a/frontends/web/src/common/ApiService.js b/frontends/web/src/common/ApiService.js index 10386c6cf..242b6cdf4 100644 --- a/frontends/web/src/common/ApiService.js +++ b/frontends/web/src/common/ApiService.js @@ -295,7 +295,7 @@ export default class ApiService { getTask(idOrCode) { const includeCredentials = this.mode !== "mturk"; return this.doFetch( - `${this.domain}/tasks/${idOrCode}`, + `${this.alternateDomain}/task/round_and_metric_data/${idOrCode}`, { method: "GET", }, From 5c333261c8ae7decd9cc2fe8e7a2382986112a9e Mon Sep 17 00:00:00 2001 From: Sara H Date: Wed, 19 Nov 2025 14:06:49 -0500 Subject: [PATCH 17/30] unburn cc_contact variable --- backend/app/domain/services/base/task_proposals.py | 6 ++++-- .../domain/services/builder_and_evaluation/evaluation.py | 9 +++++---- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/backend/app/domain/services/base/task_proposals.py b/backend/app/domain/services/base/task_proposals.py index ef2d718c6..5c3d0b2c6 100644 --- a/backend/app/domain/services/base/task_proposals.py +++ b/backend/app/domain/services/base/task_proposals.py @@ -2,6 +2,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import os import random from app.domain.helpers.email import EmailHelper @@ -12,6 +13,7 @@ class TaskProposalService: def __init__(self): + self.cc_contact = os.getenv("MAIL_LOGIN") self.task_proposal_repository = TaskProposalRepository() self.user_repository = UserRepository() self.task_repository = TaskRepository() @@ -32,14 +34,14 @@ def add_task_proposal(self, user_id: int, name: str, desc: str, longdesc: str): user_email = self.user_repository.get_user_email(user_id)[0] self.email_helper.send( contact=user_email, - cc_contact="dynabench-site@mlcommons.org", + cc_contact=self.cc_contact, template_name="task_proposal_update.txt", msg_dict={"name": name, "code": task_code, "desc": longdesc}, subject=f"Proposal for task {task_code}", ) self.email_helper.send( contact="juan.ciro@factored.ai", - cc_contact="dynabench-site@mlcommons.org", + cc_contact=self.cc_contact, template_name="task_proposal_update.txt", msg_dict={"name": name, "code": task_code, "desc": longdesc}, subject=f"Proposal for task {task_code}", diff --git a/backend/app/domain/services/builder_and_evaluation/evaluation.py b/backend/app/domain/services/builder_and_evaluation/evaluation.py index ffd1c0c49..06efd9ea4 100644 --- a/backend/app/domain/services/builder_and_evaluation/evaluation.py +++ b/backend/app/domain/services/builder_and_evaluation/evaluation.py @@ -53,6 +53,7 @@ def __init__(self): self.s3 = self.session.client("s3") self.cloud_watch = self.session.client("cloudwatch") self.s3_bucket = os.getenv("AWS_S3_BUCKET") + self.cc_contact = os.getenv("MAIL_LOGIN") self.builder = BuilderService() self.task_repository = TaskRepository() self.score_repository = ScoreRepository() @@ -449,7 +450,7 @@ def evaluation_with_selected_langs( user_email = self.user_repository.get_user_email(user_id)[0] self.email_helper.send( contact=user_email, - cc_contact="dynabench-site@mlcommons.org", + cc_contact=self.cc_contact, template_name="model_train_successful.txt", msg_dict={"name": model_name, "model_id": model_id}, subject=f"Model {model_name} training succeeded.", @@ -561,7 +562,7 @@ def evaluation( user_email = self.user_repository.get_user_email(user_id)[0] self.email_helper.send( contact=user_email, - cc_contact="dynabench-site@mlcommons.org", + cc_contact=self.cc_contact, template_name="model_train_successful.txt", msg_dict={"name": model_name, "model_id": model_id}, subject=f"Model {model_name} training succeeded.", @@ -572,7 +573,7 @@ def evaluation( user_email = self.user_repository.get_user_email(user_id)[0] self.email_helper.send( contact=user_email, - cc_contact="dynabench-site@mlcommons.org", + cc_contact=self.cc_contact, template_name="model_eval_failed.txt", msg_dict={"name": model_name, "model_id": model_id}, subject=f"Model {model_name} evaluation failed.", @@ -843,7 +844,7 @@ def evaluate_downstream_tasks(self, task_id: int, predictions: str, model_id: in self.score_repository.fix_f1_score(model_id) self.email_helper.send( contact=user_email, - cc_contact="dynabench-site@mlcommons.org", + cc_contact=self.cc_contact, template_name="model_train_successful.txt", msg_dict={"name": model_id, "model_id": model_id}, subject=f"Model {model_id} training succeeded.", From 41ffadd89a0c364885e3967b88ed0c5bc6599e20 Mon Sep 17 00:00:00 2001 From: Sara H Date: Thu, 20 Nov 2025 15:12:20 -0500 Subject: [PATCH 18/30] Move Trends from Bottle to FastAPI --- backend/app/api/endpoints/base/task.py | 5 + backend/app/domain/services/base/score.py | 218 ++++++++++++++++++ backend/app/domain/services/base/task.py | 126 +++++++++- .../app/infrastructure/repositories/score.py | 2 +- .../app/infrastructure/repositories/task.py | 20 +- frontends/web/src/common/ApiService.js | 72 +++--- 6 files changed, 393 insertions(+), 50 deletions(-) diff --git a/backend/app/api/endpoints/base/task.py b/backend/app/api/endpoints/base/task.py index e51e3ebc0..395b7031b 100644 --- a/backend/app/api/endpoints/base/task.py +++ b/backend/app/api/endpoints/base/task.py @@ -160,3 +160,8 @@ def get_tasks(): @router.get("/round_and_metric_data/{task_code}", response_model={}) def get_task_with_round_and_metric_data(task_code: str): return TaskService().get_task_with_round_and_metric_data(task_code) + + +@router.get("/trends/{task_id}", response_model={}) +def get_task_trends(task_id: int): + return TaskService().get_task_trends(task_id) diff --git a/backend/app/domain/services/base/score.py b/backend/app/domain/services/base/score.py index 00e832fab..08171c144 100644 --- a/backend/app/domain/services/base/score.py +++ b/backend/app/domain/services/base/score.py @@ -3,6 +3,7 @@ # LICENSE file in the root directory of this source tree. import json +import math import os import boto3 @@ -484,3 +485,220 @@ def add_scores_and_update_model( return {"response": "Scores added successfully"} except Exception as e: return {"error": str(e)} + + def get_dynaboard_by_task( + self, + tid: int, + include_unpublished_models: bool, + perf_metric_field_name, + ordered_metrics_with_weight_and_conversion, + ordered_dids_with_weight, + sort_by="dynascore", + reverse_sort=False, + limit=5, + offset=0, + ): + ordered_dids = [ + did_and_weight["did"] for did_and_weight in ordered_dids_with_weight + ] + try: + scores_users_datasets_models = ( + self.score_repository.get_scores_users_dataset_and_model_by_task_id( + tid, + ordered_dids, + include_unpublished_models, + ) + ) + scores, users, datasets, models = zip(*scores_users_datasets_models) + + except Exception: + return ({"count": 0, "data": []},) + + scores, users, datasets, models = ( + set(scores), + set(users), + set(datasets), + set(models), + ) + + # Order datasets as in ordered_dids, for display purposes + ordered_datasets = [] + did_to_dataset = {} + for dataset in datasets: + did_to_dataset[dataset.id] = dataset + for dataset in datasets: + ordered_datasets.append(did_to_dataset[ordered_dids[len(ordered_datasets)]]) + datasets = ordered_datasets + + # Filter models and scores so that we have complete sets of scores. + # Unclear what the "null" values should be if we wanted to complete them. + mid_to_unique_dids = {} + all_unique_dids = set(ordered_dids) + for score in scores: + complete_score_for_dataset = True + score_metadata_dict = self._get_metadata_dict(score) + + for metric_info in ordered_metrics_with_weight_and_conversion: + if (score.__dict__.get(metric_info["field_name"], None) is None) and ( + score.metadata_json is None + or score_metadata_dict.get(metric_info["field_name"], None) is None + ): + complete_score_for_dataset = False + if complete_score_for_dataset: + if score.mid in mid_to_unique_dids: + mid_to_unique_dids[score.mid].add(score.did) + else: + mid_to_unique_dids[score.mid] = {score.did} + filtered_scores = [] + for score in scores: + if mid_to_unique_dids.get(score.mid, set()) == all_unique_dids: + filtered_scores.append(score) + scores = filtered_scores + filtered_models = [] + for model in models: + if mid_to_unique_dids.get(model.id, set()) == all_unique_dids: + filtered_models.append(model) + models = filtered_models + + mid_and_did_to_scores = {} + for score in scores: + mid_and_did_to_scores[(score.mid, score.did)] = score + dataset_results_dict = {} + for dataset in datasets: + dataset_results_dict[dataset.id] = { + metric_info["field_name"]: [] + for metric_info in ordered_metrics_with_weight_and_conversion + } + for model in models: + score = mid_and_did_to_scores[(model.id, dataset.id)] + for field_name in dataset_results_dict[dataset.id]: + result = score.__dict__.get(field_name, None) + if result is None: + metadata_dict = self._get_metadata_dict(score) + result = metadata_dict.get(field_name) + dataset_results_dict[dataset.id][field_name].append(result) + + # Average the results accross datasets. + averaged_dataset_results = None + did_to_weight = { + did_and_weight["did"]: did_and_weight["weight"] + for did_and_weight in ordered_dids_with_weight + } + for key, value in dataset_results_dict.items(): + df = pd.DataFrame.from_dict(value) + dataset_results_dict[key] = df + if averaged_dataset_results is None: + averaged_dataset_results = did_to_weight[key] * df + else: + averaged_dataset_results += did_to_weight[key] * df + + # Compute the dynascore. + converted_dataset_results = self.calculate_dynascore( + perf_metric_field_name, + averaged_dataset_results, + weights={ + metric_info["field_name"]: metric_info["weight"] + for metric_info in ordered_metrics_with_weight_and_conversion + }, + direction_multipliers={ + metric_info["field_name"]: metric_info["utility_direction"] + for metric_info in ordered_metrics_with_weight_and_conversion + }, + offsets={ + metric_info["field_name"]: metric_info["offset"] + for metric_info in ordered_metrics_with_weight_and_conversion + }, + ) + uid_to_username = {} + for user in users: + uid_to_username[user.id] = user.username + data_list = [] + model_index = 0 + ordered_metric_field_names = [ + metric_info["field_name"] + for metric_info in ordered_metrics_with_weight_and_conversion + ] + for model in models: + datasets_list = [] + for dataset in datasets: + scores = [] + for field_name in ordered_metric_field_names: + scores.append( + dataset_results_dict[dataset.id][field_name][model_index] + ) + variances = [0] * len(scores) # TODO + datasets_list.append( + { + "id": dataset.id, + "name": dataset.name, + "scores": scores, + "variances": variances, + } + ) + averaged_scores = [] + for field_name in ordered_metric_field_names: + averaged_scores.append( + averaged_dataset_results[field_name][model_index] + ) + averaged_variances = [0] * len(averaged_scores) # TODO + dynascore = converted_dataset_results["dynascore"][model_index] + data_list.append( + { + "model_id": model.id, + "model_name": model.name if model.is_published else None, + # Don't give away the users for unpublished models. + "uid": model.uid + if model.is_published and not model.is_anonymous + else None, + "username": uid_to_username[model.uid] + if model.is_published and not model.is_anonymous + else None, + "averaged_scores": averaged_scores, + "averaged_variances": averaged_variances, + "dynascore": dynascore + if not math.isnan(dynascore) + else 0, # It is possible for the dynascore to be nan if + # the leaderboard is uninteresting. For example, if, + # for any metric, all models on the leaderboard have that + # metric as 0. In these cases, dynascores for all models + # will be nan. + "dynavariance": 0, # TODO + "datasets": datasets_list, + } + ) + model_index += 1 + ordered_metric_pretty_names = [ + metric_info["pretty_name"] + for metric_info in ordered_metrics_with_weight_and_conversion + ] + if sort_by == "dynascore": + data_list.sort(reverse=reverse_sort, key=lambda model: model["dynascore"]) + elif sort_by in ordered_metric_pretty_names: + data_list.sort( + reverse=reverse_sort, + key=lambda model: model["averaged_scores"][ + ordered_metric_pretty_names.index(sort_by) + ], + ) + elif sort_by == "model_name": + data_list.sort(reverse=reverse_sort, key=lambda model: model["model_name"]) + + return { + "count": len(data_list), + "data": data_list[offset : offset + limit], + } + + def _get_metadata_dict(self, score): + """Safely parse metadata_json string to dictionary""" + if not score.metadata_json: + return {} + + try: + if isinstance(score.metadata_json, str): + return json.loads(score.metadata_json) + elif isinstance(score.metadata_json, dict): + return score.metadata_json + else: + return {} + except (json.JSONDecodeError, TypeError): + return {} diff --git a/backend/app/domain/services/base/task.py b/backend/app/domain/services/base/task.py index c74eb5d0a..5ca0587aa 100644 --- a/backend/app/domain/services/base/task.py +++ b/backend/app/domain/services/base/task.py @@ -6,6 +6,7 @@ import os import random from ast import literal_eval +from typing import Union import boto3 import yaml @@ -393,9 +394,11 @@ def get_tasks(self, exclude_hidden: bool = True): flag += 1 return converted_tasks - def get_task_with_round_and_metric_data(self, task_code: str): + def get_task_with_round_and_metric_data(self, task_id_or_code: Union[int, str]): try: - task, round_obj = self.task_repository.get_task_with_round_info(task_code) + task, round_obj = self.task_repository.get_task_with_round_info( + task_id_or_code + ) datasets = self.dataset_repository.get_order_datasets_by_task_id(task.id) dataset_list = [] @@ -449,7 +452,7 @@ def get_task_with_round_and_metric_data(self, task_code: str): task_dict["ordered_metrics"] = ordered_metrics task_dict["round"] = round_dict return task_dict - except Exception as e: + except Exception: return False def get_task_metrics_meta(self, task): @@ -480,3 +483,120 @@ def get_task_metrics_meta(self, task): for metric in ordered_metric_field_names } return metrics_meta, ordered_metric_field_names + + def get_task_trends(self, task_id: int): + """ + Get top perform models and its round wise performance metrics at task level + It will fetch only top 10 models and its round wise performance metrics + :param tid: Task id + :return: Json Object + """ + task_dict = self.get_task_with_round_and_metric_data(task_id) + ordered_metric_and_weight = list( + map( + lambda metric: dict({"weight": metric["default_weight"]}, **metric), + task_dict["ordered_metrics"], + ) + ) + ordered_did_and_weight = list( + map( + lambda dataset: dict( + {"weight": dataset["default_weight"], "did": dataset["id"]}, + **dataset, + ), + task_dict["ordered_scoring_datasets"], + ) + ) + dynaboard_response = self.score_services.get_dynaboard_by_task( + task_id, + task_dict.get("unpublished_models_in_leaderboard"), + task_dict.get("perf_metric_field_name"), + ordered_metric_and_weight, + ordered_did_and_weight, + "dynascore", + True, + 10, + 0, + ) + mid_and_rid_to_perf = {} + did_to_rid = {} + for dataset in self.dataset_repository.get_all(): + did_to_rid[dataset.id] = dataset.rid + rid_to_did_to_weight = {} + for did_and_weight in ordered_did_and_weight: + rid = did_to_rid[did_and_weight["did"]] + if rid in rid_to_did_to_weight: + rid_to_did_to_weight[rid][did_and_weight["did"]] = did_and_weight[ + "weight" + ] + else: + rid_to_did_to_weight[rid] = { + did_and_weight["did"]: did_and_weight["weight"] + } + mid_to_name = {} + for model in self.model_repository.get_all(): + mid_to_name[model.id] = model.name + + if isinstance(dynaboard_response, tuple): + dynaboard_response = dynaboard_response[0] + + for model_results in dynaboard_response["data"]: + for dataset_results in model_results["datasets"]: + rid = did_to_rid[dataset_results["id"]] + if rid != 0: + ordered_metric_field_names = list( + map( + lambda metric: metric["field_name"], + task_dict["ordered_metrics"], + ) + ) + perf = dataset_results["scores"][ + ordered_metric_field_names.index( + task_dict["perf_metric_field_name"] + ) + ] + mid_and_rid = (model_results["model_id"], rid) + # Weighting is needed in case there are multiple scoring + # datasets for the same round. + weighted_perf = ( + perf + * rid_to_did_to_weight[rid][dataset_results["id"]] + / sum(rid_to_did_to_weight[rid].values()) + ) + if mid_and_rid in mid_and_rid_to_perf: + mid_and_rid_to_perf[ + (model_results["model_id"], rid) + ] += weighted_perf + else: + mid_and_rid_to_perf[ + (model_results["model_id"], rid) + ] = weighted_perf + query_result = [] + for (mid, rid), perf in mid_and_rid_to_perf.items(): + query_result.append( + { + "model_id": mid, + "model_name": mid_to_name[mid], + "performance": perf, + "round_id": rid, + } + ) + + response_obj = {} + for result in query_result: + round_id = result["round_id"] + model_key = f"{result['model_name']}_{result['model_id']}" + + if round_id in response_obj: + response_obj[round_id][model_key] = result["performance"] + else: + response_obj[round_id] = { + "round": round_id, + model_key: result["performance"], + } + + return ( + sorted(list(response_obj.values()), key=lambda x: x["round"]) + if response_obj + else [] + ) diff --git a/backend/app/infrastructure/repositories/score.py b/backend/app/infrastructure/repositories/score.py index 1b2ca847d..dbc6e123b 100644 --- a/backend/app/infrastructure/repositories/score.py +++ b/backend/app/infrastructure/repositories/score.py @@ -30,7 +30,7 @@ def get_scores_users_dataset_and_model_by_task_id( .filter(Model.tid == task_id) .filter(Score.did.in_(ordered_datasets_id)) ) - if unpublished_models_in_leaderboard: + if not unpublished_models_in_leaderboard: query = query.filter(Model.is_published) return query.all() diff --git a/backend/app/infrastructure/repositories/task.py b/backend/app/infrastructure/repositories/task.py index a61609348..56dee7349 100644 --- a/backend/app/infrastructure/repositories/task.py +++ b/backend/app/infrastructure/repositories/task.py @@ -199,13 +199,13 @@ def get_tasks_current_round(self, filters: dict = {}): instances = query.all() return self.instance_converter.instance_to_dict(instances) - def get_task_with_round_info(self, task_code: str): - return ( - self.session.query(self.model, Round) - .filter(self.model.task_code == task_code) - .join( - Round, - (Round.tid == self.model.id) & (Round.rid == self.model.cur_round), - ) - .first() - ) + def get_task_with_round_info(self, task_code_or_id: str): + query = self.session.query(self.model, Round) + if isinstance(task_code_or_id, int) or task_code_or_id.isdigit(): + query = query.filter(self.model.id == int(task_code_or_id)) + else: + query = query.filter(self.model.task_code == task_code_or_id) + return query.join( + Round, + (Round.tid == self.model.id) & (Round.rid == self.model.cur_round), + ).first() diff --git a/frontends/web/src/common/ApiService.js b/frontends/web/src/common/ApiService.js index 242b6cdf4..ee46f64a0 100644 --- a/frontends/web/src/common/ApiService.js +++ b/frontends/web/src/common/ApiService.js @@ -101,7 +101,7 @@ export default class ApiService { method: "PUT", body: JSON.stringify(data), }, - includeCredentials + includeCredentials, ); } @@ -210,7 +210,7 @@ export default class ApiService { } getTrends(taskId) { - return this.fetch(`${this.domain}/tasks/${taskId}/trends`, { + return this.fetch(`${this.alternateDomain}/task/trends/${taskId}`, { method: "GET", }); } @@ -222,7 +222,7 @@ export default class ApiService { sort, sortDirection, metricWeights, - datasetWeights + datasetWeights, ) { const pageQuery = `limit=${limit || 10}&offset=${offset || 0}`; const sortQuery = @@ -236,7 +236,7 @@ export default class ApiService { const datasetWeightsQuery = datasetWeights ? `&ordered_scoring_dataset_weights=${encodeURIComponent( - datasetWeights.join("|") + datasetWeights.join("|"), )}` : ""; @@ -299,7 +299,7 @@ export default class ApiService { { method: "GET", }, - includeCredentials + includeCredentials, ); } @@ -316,12 +316,12 @@ export default class ApiService { `${ this.domain }/contexts/${tid}/${rid}/${method}?tags=${encodeURIComponent( - tags.join("|") + tags.join("|"), )}`, { method: "GET", }, - includeCredentials + includeCredentials, ); } @@ -330,7 +330,7 @@ export default class ApiService { rid, tags = [], context_tags = [], - annotator_id = null + annotator_id = null, ) { const includeCredentials = this.mode !== "mturk"; @@ -341,12 +341,12 @@ export default class ApiService { : ""; return this.doFetch( `${this.domain}/examples/${tid}/${rid}?tags=${encodeURIComponent( - tags.join("|") + tags.join("|"), )}${context_tags_query}${annotator_query}`, { method: "GET", }, - includeCredentials + includeCredentials, ); } @@ -357,17 +357,17 @@ export default class ApiService { maxNumFlags, minNumDisagreements, maxNumDisagreements, - tags = [] + tags = [], ) { return this.fetch( `${ this.domain }/examples/${tid}/${rid}/filtered/${minNumFlags}/${maxNumFlags}/${minNumDisagreements}/${maxNumDisagreements}?tags=${encodeURIComponent( - tags.join("|") + tags.join("|"), )}`, { method: "GET", - } + }, ); } @@ -387,7 +387,7 @@ export default class ApiService { `${this.domain}/notifications?limit=${limit || 10}&offset=${offset || 0}`, { method: "GET", - } + }, ); } @@ -398,7 +398,7 @@ export default class ApiService { }`, { method: "GET", - } + }, ); } @@ -409,7 +409,7 @@ export default class ApiService { }`, { method: "GET", - } + }, ); } @@ -418,7 +418,7 @@ export default class ApiService { `${this.domain}/users/${userId}/forks?limit=${limit}&offset=${offset}`, { method: "GET", - } + }, ); } @@ -427,7 +427,7 @@ export default class ApiService { `${this.domain}/users/${userId}/snapshots?limit=${limit}&offset=${offset}`, { method: "GET", - } + }, ); } @@ -499,7 +499,7 @@ export default class ApiService { { method: "GET", }, - includeCredentials + includeCredentials, ); } @@ -524,7 +524,7 @@ export default class ApiService { method: "PUT", body: JSON.stringify(obj), }, - includeCredentials + includeCredentials, ); } @@ -544,7 +544,7 @@ export default class ApiService { method: "PUT", body: JSON.stringify(obj), }, - includeCredentials + includeCredentials, ); } @@ -566,7 +566,7 @@ export default class ApiService { method: "POST", body: JSON.stringify(data), }, - false + false, ); } @@ -587,7 +587,7 @@ export default class ApiService { `${this.domain}/task_proposals/all/${page}/${pageLimit}`, { method: "GET", - } + }, ); } @@ -596,7 +596,7 @@ export default class ApiService { `${this.domain}/task_proposals/user/${page}/${pageLimit}`, { method: "GET", - } + }, ); } @@ -662,7 +662,7 @@ export default class ApiService { { method: "PUT", body: JSON.stringify(data), - } + }, ); } @@ -747,7 +747,7 @@ export default class ApiService { headers: { Authorization: token ? "Bearer " + token : "None", }, - } + }, ); } @@ -765,7 +765,7 @@ export default class ApiService { headers: { Authorization: token ? "Bearer " + token : "None", }, - } + }, ); } @@ -780,7 +780,7 @@ export default class ApiService { metadata, modelWrong, tag = null, - modelEndpointName = null + modelEndpointName = null, ) { const includeCredentials = this.mode !== "mturk"; return this.doFetch( @@ -801,7 +801,7 @@ export default class ApiService { model_endpoint_name: modelEndpointName, }), }, - includeCredentials + includeCredentials, ); } @@ -821,7 +821,7 @@ export default class ApiService { `${this.domain}/tasks/${tid}/leaderboard_configuration/${name}`, { method: "GET", - } + }, ); } @@ -834,7 +834,7 @@ export default class ApiService { orderedDatasetWeights, totalCount, description, - name + name, ) { return this.fetch(`${this.domain}/tasks/${tid}/leaderboard_snapshot`, { method: "PUT", @@ -856,7 +856,7 @@ export default class ApiService { `${this.domain}/tasks/${tid}/leaderboard_snapshot/${name}`, { method: "GET", - } + }, ); } @@ -865,7 +865,7 @@ export default class ApiService { `${this.domain}/tasks/${tid}/disambiguate_forks_and_snapshots/${name}`, { method: "GET", - } + }, ); } @@ -889,7 +889,7 @@ export default class ApiService { localStorage.removeItem("id_token"); //window.location.href = '/login'; return false; - } + }, ); } } @@ -968,7 +968,7 @@ export default class ApiService { return this.doFetch(`${this.alternateDomain}/auth/refresh`, {}, true).then( (result) => { this.setToken(result.token); - } + }, ); } @@ -1011,7 +1011,7 @@ export default class ApiService { localStorage.removeItem("id_token"); //window.location.href = '/login'; throw new Error("Could not refresh token"); - } + }, ); return this.doFetch(url, options, true); From ff9059c41e66ba835e1f639b8e505e0001059cd0 Mon Sep 17 00:00:00 2001 From: Sara H Date: Thu, 20 Nov 2025 15:20:50 -0500 Subject: [PATCH 19/30] run prettier --- frontends/web/src/common/ApiService.js | 70 +++++++++++++------------- 1 file changed, 35 insertions(+), 35 deletions(-) diff --git a/frontends/web/src/common/ApiService.js b/frontends/web/src/common/ApiService.js index ee46f64a0..e1cf7561b 100644 --- a/frontends/web/src/common/ApiService.js +++ b/frontends/web/src/common/ApiService.js @@ -101,7 +101,7 @@ export default class ApiService { method: "PUT", body: JSON.stringify(data), }, - includeCredentials, + includeCredentials ); } @@ -222,7 +222,7 @@ export default class ApiService { sort, sortDirection, metricWeights, - datasetWeights, + datasetWeights ) { const pageQuery = `limit=${limit || 10}&offset=${offset || 0}`; const sortQuery = @@ -236,7 +236,7 @@ export default class ApiService { const datasetWeightsQuery = datasetWeights ? `&ordered_scoring_dataset_weights=${encodeURIComponent( - datasetWeights.join("|"), + datasetWeights.join("|") )}` : ""; @@ -299,7 +299,7 @@ export default class ApiService { { method: "GET", }, - includeCredentials, + includeCredentials ); } @@ -316,12 +316,12 @@ export default class ApiService { `${ this.domain }/contexts/${tid}/${rid}/${method}?tags=${encodeURIComponent( - tags.join("|"), + tags.join("|") )}`, { method: "GET", }, - includeCredentials, + includeCredentials ); } @@ -330,7 +330,7 @@ export default class ApiService { rid, tags = [], context_tags = [], - annotator_id = null, + annotator_id = null ) { const includeCredentials = this.mode !== "mturk"; @@ -341,12 +341,12 @@ export default class ApiService { : ""; return this.doFetch( `${this.domain}/examples/${tid}/${rid}?tags=${encodeURIComponent( - tags.join("|"), + tags.join("|") )}${context_tags_query}${annotator_query}`, { method: "GET", }, - includeCredentials, + includeCredentials ); } @@ -357,17 +357,17 @@ export default class ApiService { maxNumFlags, minNumDisagreements, maxNumDisagreements, - tags = [], + tags = [] ) { return this.fetch( `${ this.domain }/examples/${tid}/${rid}/filtered/${minNumFlags}/${maxNumFlags}/${minNumDisagreements}/${maxNumDisagreements}?tags=${encodeURIComponent( - tags.join("|"), + tags.join("|") )}`, { method: "GET", - }, + } ); } @@ -387,7 +387,7 @@ export default class ApiService { `${this.domain}/notifications?limit=${limit || 10}&offset=${offset || 0}`, { method: "GET", - }, + } ); } @@ -398,7 +398,7 @@ export default class ApiService { }`, { method: "GET", - }, + } ); } @@ -409,7 +409,7 @@ export default class ApiService { }`, { method: "GET", - }, + } ); } @@ -418,7 +418,7 @@ export default class ApiService { `${this.domain}/users/${userId}/forks?limit=${limit}&offset=${offset}`, { method: "GET", - }, + } ); } @@ -427,7 +427,7 @@ export default class ApiService { `${this.domain}/users/${userId}/snapshots?limit=${limit}&offset=${offset}`, { method: "GET", - }, + } ); } @@ -499,7 +499,7 @@ export default class ApiService { { method: "GET", }, - includeCredentials, + includeCredentials ); } @@ -524,7 +524,7 @@ export default class ApiService { method: "PUT", body: JSON.stringify(obj), }, - includeCredentials, + includeCredentials ); } @@ -544,7 +544,7 @@ export default class ApiService { method: "PUT", body: JSON.stringify(obj), }, - includeCredentials, + includeCredentials ); } @@ -566,7 +566,7 @@ export default class ApiService { method: "POST", body: JSON.stringify(data), }, - false, + false ); } @@ -587,7 +587,7 @@ export default class ApiService { `${this.domain}/task_proposals/all/${page}/${pageLimit}`, { method: "GET", - }, + } ); } @@ -596,7 +596,7 @@ export default class ApiService { `${this.domain}/task_proposals/user/${page}/${pageLimit}`, { method: "GET", - }, + } ); } @@ -662,7 +662,7 @@ export default class ApiService { { method: "PUT", body: JSON.stringify(data), - }, + } ); } @@ -747,7 +747,7 @@ export default class ApiService { headers: { Authorization: token ? "Bearer " + token : "None", }, - }, + } ); } @@ -765,7 +765,7 @@ export default class ApiService { headers: { Authorization: token ? "Bearer " + token : "None", }, - }, + } ); } @@ -780,7 +780,7 @@ export default class ApiService { metadata, modelWrong, tag = null, - modelEndpointName = null, + modelEndpointName = null ) { const includeCredentials = this.mode !== "mturk"; return this.doFetch( @@ -801,7 +801,7 @@ export default class ApiService { model_endpoint_name: modelEndpointName, }), }, - includeCredentials, + includeCredentials ); } @@ -821,7 +821,7 @@ export default class ApiService { `${this.domain}/tasks/${tid}/leaderboard_configuration/${name}`, { method: "GET", - }, + } ); } @@ -834,7 +834,7 @@ export default class ApiService { orderedDatasetWeights, totalCount, description, - name, + name ) { return this.fetch(`${this.domain}/tasks/${tid}/leaderboard_snapshot`, { method: "PUT", @@ -856,7 +856,7 @@ export default class ApiService { `${this.domain}/tasks/${tid}/leaderboard_snapshot/${name}`, { method: "GET", - }, + } ); } @@ -865,7 +865,7 @@ export default class ApiService { `${this.domain}/tasks/${tid}/disambiguate_forks_and_snapshots/${name}`, { method: "GET", - }, + } ); } @@ -889,7 +889,7 @@ export default class ApiService { localStorage.removeItem("id_token"); //window.location.href = '/login'; return false; - }, + } ); } } @@ -968,7 +968,7 @@ export default class ApiService { return this.doFetch(`${this.alternateDomain}/auth/refresh`, {}, true).then( (result) => { this.setToken(result.token); - }, + } ); } @@ -1011,7 +1011,7 @@ export default class ApiService { localStorage.removeItem("id_token"); //window.location.href = '/login'; throw new Error("Could not refresh token"); - }, + } ); return this.doFetch(url, options, true); From fb588088014bb9a09e61fe29f709740555b6c3d3 Mon Sep 17 00:00:00 2001 From: Sara H Date: Fri, 21 Nov 2025 01:41:17 -0500 Subject: [PATCH 20/30] Move Task Owner endpoint to FastAPI --- backend/app/api/endpoints/base/task.py | 46 ++++++++++- backend/app/domain/services/base/score.py | 2 + backend/app/domain/services/base/task.py | 82 +++++++++++++++++++ backend/app/infrastructure/models/models.py | 4 +- .../app/infrastructure/repositories/model.py | 3 + .../app/infrastructure/repositories/round.py | 3 + .../app/infrastructure/repositories/task.py | 7 ++ .../repositories/taskuserpermission.py | 28 +++++++ .../app/infrastructure/repositories/user.py | 7 ++ frontends/web/src/common/ApiService.js | 22 +++-- 10 files changed, 193 insertions(+), 11 deletions(-) diff --git a/backend/app/api/endpoints/base/task.py b/backend/app/api/endpoints/base/task.py index 395b7031b..c471749a0 100644 --- a/backend/app/api/endpoints/base/task.py +++ b/backend/app/api/endpoints/base/task.py @@ -3,9 +3,11 @@ # LICENSE file in the root directory of this source tree. import os -from fastapi import APIRouter +from fastapi import APIRouter, Body, Depends, Request from fastapi.responses import FileResponse +from app.api.middleware.authentication import validate_access_token +from app.domain.auth.authentication import LoginService from app.domain.schemas.base.task import ( CheckSignConsentRequest, GetDynaboardInfoByTaskIdRequest, @@ -165,3 +167,45 @@ def get_task_with_round_and_metric_data(task_code: str): @router.get("/trends/{task_id}", response_model={}) def get_task_trends(task_id: int): return TaskService().get_task_trends(task_id) + + +@router.put("/update/{task_id}", response_model={}) +async def update_task( + task_id: int, + request: Request, + datadict=Body(...), + token_payload=Depends(validate_access_token), +): + if not LoginService().is_admin_or_owner(task_id, request): + raise PermissionError("Unauthorized access to update task data.") + return TaskService().update_task(task_id, datadict) + + +@router.get("/owners/{task_id}", response_model={}) +async def get_task_owners( + task_id: int, request: Request, token_payload=Depends(validate_access_token) +): + if not LoginService().is_admin_or_owner(task_id, request): + raise PermissionError("Unauthorized access to get task data.") + return TaskService().get_task_owners(task_id) + + +@router.put("/toggle_owner/{task_id}/{username}", response_model={}) +async def toogle_user_task_permission( + task_id: int, + username: str, + request: Request, + token_payload=Depends(validate_access_token), +): + if not LoginService().is_admin_or_owner(task_id, request): + raise PermissionError("Unauthorized access to update task data.") + return TaskService().toogle_user_task_permission(task_id, username) + + +@router.get("/get_models_in_the_loop/{task_id}", response_model={}) +async def get_models_in_the_loop( + task_id: int, request: Request, token_payload=Depends(validate_access_token) +): + if not LoginService().is_admin_or_owner(task_id, request): + raise PermissionError("Unauthorized access to get task data.") + return TaskService().get_models_in_the_loop(task_id) diff --git a/backend/app/domain/services/base/score.py b/backend/app/domain/services/base/score.py index 08171c144..e4d01c19d 100644 --- a/backend/app/domain/services/base/score.py +++ b/backend/app/domain/services/base/score.py @@ -347,6 +347,8 @@ def calculate_dynascore( def get_maximun_principal_score_by_task(self, task_id: int) -> float: yaml_file = self.task_repository.get_config_file_by_task_id(task_id)[0] + if not yaml_file: + return {"perf": 0.00} yaml_file = yaml.safe_load(yaml_file) perf_metric = yaml_file.get("perf_metric", {}) if isinstance(perf_metric, list): diff --git a/backend/app/domain/services/base/task.py b/backend/app/domain/services/base/task.py index 5ca0587aa..56215abd6 100644 --- a/backend/app/domain/services/base/task.py +++ b/backend/app/domain/services/base/task.py @@ -23,8 +23,12 @@ from app.infrastructure.repositories.example import ExampleRepository from app.infrastructure.repositories.historical_data import HistoricalDataRepository from app.infrastructure.repositories.model import ModelRepository +from app.infrastructure.repositories.round import RoundRepository from app.infrastructure.repositories.task import TaskRepository from app.infrastructure.repositories.taskcategories import TaskCategoriesRepository +from app.infrastructure.repositories.taskuserpermission import ( + TaskUserPermissionRepository, +) from app.infrastructure.repositories.user import UserRepository from app.infrastructure.repositories.validation import ValidationRepository @@ -36,10 +40,12 @@ def __init__(self): self.model_repository = ModelRepository() self.example_repository = ExampleRepository() self.score_services = ScoreService() + self.round_repository = RoundRepository() self.task_categories_repository = TaskCategoriesRepository() self.user_repository = UserRepository() self.validation_repository = ValidationRepository() self.historical_task_repository = HistoricalDataRepository() + self.task_user_permission_repository = TaskUserPermissionRepository() self.session = boto3.Session( aws_access_key_id=os.getenv("AWS_ACCESS_KEY_ID"), aws_secret_access_key=os.getenv("AWS_SECRET_ACCESS_KEY"), @@ -600,3 +606,79 @@ def get_task_trends(self, task_id: int): if response_obj else [] ) + + def update_task(self, task_id, data): + for field in data: + if field not in ( + "unpublished_models_in_leaderboard", + "validate_non_fooling", + "num_matching_validations", + "instructions_md", + "predictions_upload_instructions_md", + "train_file_upload_instructions_md", + "mlcube_tutorial_markdown", + "dynamic_adversarial_data_collection", + "dynamic_adversarial_data_validation", + "hidden", + "submitable", + "create_endpoint", + "build_sqs_queue", + "eval_sqs_queue", + "is_decen_task", + "task_aws_account_id", + "task_gateway_predict_prefix", + "config_yaml", + "context", + "leaderboard_description", + ): + raise HTTPException( + status_code=403, detail=f"Field {field} cannot be updated." + ) + return self.task_repository.update_task(task_id, data) + + def get_task_owners(self, task_id): + tasks = self.task_user_permission_repository.get_task_owners(task_id) + users = [] + for user in tasks: + user_name = self.user_repository.get_user_name_by_id(user["uid"]) + users.append({"user_id": user["uid"], "username": user_name["username"]}) + return users + + def toogle_user_task_permission(self, task_id: int, username: str): + user_to_toggle = self.user_repository.get_user_by_username(username) + + if (task_id, "owner") in [ + (perm.tid, perm.type) for perm in user_to_toggle.task_permissions + ]: + self.task_user_permission_repository.delete_task_user_permission( + task_id, user_to_toggle.id, "owner" + ) + print("Removed task owner: " + username) + else: + self.task_user_permission_repository.create_user_task_permission( + task_id, user_to_toggle.id, "owner" + ) + print("Added task owner: " + username) + + return {"success": "ok"} + + def get_models_in_the_loop(self, task_id: int): + rounds = self.round_repository.get_rounds_by_task_id(task_id) + models = self.model_repository.get_models_by_task_id(task_id) + rid_to_model_identifiers = {} + for round in rounds: + model_identifiers = [] + for model in models: + if model.light_model: + if model.is_published and model.deployment_status == "deployed": + model_identifiers.append( + { + "model_name": model.name, + "model_id": model.id, + "uid": model.uid, + "username": model.user.username, + "is_in_the_loop": model.is_in_the_loop, + } + ) + rid_to_model_identifiers[round.rid] = model_identifiers + return rid_to_model_identifiers diff --git a/backend/app/infrastructure/models/models.py b/backend/app/infrastructure/models/models.py index b2986f87d..5b036d4ee 100644 --- a/backend/app/infrastructure/models/models.py +++ b/backend/app/infrastructure/models/models.py @@ -375,8 +375,8 @@ class TaskUserPermission(Base): type = Column(String(255)) tid = Column(ForeignKey("tasks.id"), index=True) - task = relationship("Task") - user = relationship("User") + task = relationship("Task", backref="task_permissions") + user = relationship("User", backref="task_permissions") class Context(Base): diff --git a/backend/app/infrastructure/repositories/model.py b/backend/app/infrastructure/repositories/model.py index a4460f676..03398a11c 100644 --- a/backend/app/infrastructure/repositories/model.py +++ b/backend/app/infrastructure/repositories/model.py @@ -307,3 +307,6 @@ def get_amount_of_models_uploaded_in_hr_diff( ) .scalar() ) + + def get_models_by_task_id(self, task_id: int): + return self.session.query(self.model).filter(self.model.tid == task_id).all() diff --git a/backend/app/infrastructure/repositories/round.py b/backend/app/infrastructure/repositories/round.py index 96e767e0e..3f08579d3 100644 --- a/backend/app/infrastructure/repositories/round.py +++ b/backend/app/infrastructure/repositories/round.py @@ -59,3 +59,6 @@ def get_examples_collected_per_round(self, round_id: int, task_id: int): .filter((self.model.rid == round_id) & (self.model.tid == task_id)) .first() ) + + def get_rounds_by_task_id(self, task_id: int): + return self.session.query(self.model).filter(self.model.tid == task_id).all() diff --git a/backend/app/infrastructure/repositories/task.py b/backend/app/infrastructure/repositories/task.py index 56dee7349..7b4a2e10b 100644 --- a/backend/app/infrastructure/repositories/task.py +++ b/backend/app/infrastructure/repositories/task.py @@ -209,3 +209,10 @@ def get_task_with_round_info(self, task_code_or_id: str): Round, (Round.tid == self.model.id) & (Round.rid == self.model.cur_round), ).first() + + def update_task(self, task_id: int, update_data: dict): + self.session.query(self.model).filter(self.model.id == task_id).update( + update_data + ) + self.session.flush() + self.session.commit() diff --git a/backend/app/infrastructure/repositories/taskuserpermission.py b/backend/app/infrastructure/repositories/taskuserpermission.py index abb56ddf6..011e54500 100644 --- a/backend/app/infrastructure/repositories/taskuserpermission.py +++ b/backend/app/infrastructure/repositories/taskuserpermission.py @@ -5,6 +5,7 @@ # Copyright (c) Facebook, Inc. and its affiliates. # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from sqlalchemy import and_ from app.infrastructure.models.models import TaskUserPermission from app.infrastructure.repositories.abstract import AbstractRepository @@ -30,3 +31,30 @@ def get_task_permissions_by_user_id(self, user_id: int): return [ self.instance_converter.instance_to_dict(instance) for instance in instances ] + + def get_task_owners(self, task_id: int): + instances = ( + self.session.query(self.model) + .filter(self.model.tid == task_id) + .filter(self.model.type == "owner") + .all() + ) + return [ + self.instance_converter.instance_to_dict(instance) for instance in instances + ] + + def delete_task_user_permission(self, task_id: int, user_id: int, type: str): + self.session.query(self.model).filter( + and_( + self.model.uid == user_id, + self.model.type == type, + self.model.tid == task_id, + ) + ).delete() + with self.session as session: + session.commit() + + def create_user_task_permission( + self, task_id: int, user_id: int, permission_type: str + ): + return self.add({"tid": task_id, "uid": user_id, "type": permission_type}) diff --git a/backend/app/infrastructure/repositories/user.py b/backend/app/infrastructure/repositories/user.py index 7adad49b0..35b7c8760 100644 --- a/backend/app/infrastructure/repositories/user.py +++ b/backend/app/infrastructure/repositories/user.py @@ -119,3 +119,10 @@ def download_users_info(self): return self.session.query( self.model.id, self.model.email, self.model.username ).all() + + def get_user_by_username(self, username: str): + return ( + self.session.query(self.model) + .filter(self.model.username == username) + .first() + ) diff --git a/frontends/web/src/common/ApiService.js b/frontends/web/src/common/ApiService.js index e1cf7561b..e823ae3d0 100644 --- a/frontends/web/src/common/ApiService.js +++ b/frontends/web/src/common/ApiService.js @@ -625,20 +625,23 @@ export default class ApiService { } updateTask(tid, data) { - return this.fetch(`${this.domain}/tasks/update/${tid}`, { + return this.fetch(`${this.alternateDomain}/task/update/${tid}`, { method: "PUT", body: JSON.stringify(data), }); } toggleOwner(tid, username) { - return this.fetch(`${this.domain}/tasks/toggle_owner/${tid}/${username}`, { - method: "PUT", - }); + return this.fetch( + `${this.alternateDomain}/task/toggle_owner/${tid}/${username}`, + { + method: "PUT", + } + ); } getOwners(tid, username) { - return this.fetch(`${this.domain}/tasks/owners/${tid}`, { + return this.fetch(`${this.alternateDomain}/task/owners/${tid}`, { method: "GET", }); } @@ -673,9 +676,12 @@ export default class ApiService { } getModelIdentifiersForTargetSelection(tid) { - return this.fetch(`${this.domain}/tasks/get_models_in_the_loop/${tid}`, { - method: "GET", - }); + return this.fetch( + `${this.alternateDomain}/task/get_models_in_the_loop/${tid}`, + { + method: "GET", + } + ); } getModelIdentifiers(tid) { From 14c2332267a22b3ab3e37b5232cdadac177c6270 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sara=20Hincapi=C3=A9=20M?= <43832784+shincap8@users.noreply.github.com> Date: Fri, 21 Nov 2025 09:21:20 -0500 Subject: [PATCH 21/30] Update backend/app/domain/services/base/score.py Co-authored-by: Rafael Mosquera --- backend/app/domain/services/base/score.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/app/domain/services/base/score.py b/backend/app/domain/services/base/score.py index e4d01c19d..d39e1fa1c 100644 --- a/backend/app/domain/services/base/score.py +++ b/backend/app/domain/services/base/score.py @@ -345,7 +345,7 @@ def calculate_dynascore( ) return converted_data - def get_maximun_principal_score_by_task(self, task_id: int) -> float: + def get_maximum_principal_score_by_task(self, task_id: int) -> float: yaml_file = self.task_repository.get_config_file_by_task_id(task_id)[0] if not yaml_file: return {"perf": 0.00} From be8520ba55d832a464b662e0c092333b14875e38 Mon Sep 17 00:00:00 2001 From: Sara H Date: Fri, 21 Nov 2025 09:35:44 -0500 Subject: [PATCH 22/30] Delete unnecesary if and use hidden instead of active --- backend/app/domain/services/base/score.py | 2 -- backend/app/domain/services/base/task.py | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/backend/app/domain/services/base/score.py b/backend/app/domain/services/base/score.py index e4d01c19d..08171c144 100644 --- a/backend/app/domain/services/base/score.py +++ b/backend/app/domain/services/base/score.py @@ -347,8 +347,6 @@ def calculate_dynascore( def get_maximun_principal_score_by_task(self, task_id: int) -> float: yaml_file = self.task_repository.get_config_file_by_task_id(task_id)[0] - if not yaml_file: - return {"perf": 0.00} yaml_file = yaml.safe_load(yaml_file) perf_metric = yaml_file.get("perf_metric", {}) if isinstance(perf_metric, list): diff --git a/backend/app/domain/services/base/task.py b/backend/app/domain/services/base/task.py index 56215abd6..f3b77c0ec 100644 --- a/backend/app/domain/services/base/task.py +++ b/backend/app/domain/services/base/task.py @@ -382,7 +382,7 @@ def get_task_consent(self, task_id: int): def get_tasks(self, exclude_hidden: bool = True): filters = { - "active": int(exclude_hidden), + "hidden": int(not exclude_hidden), } tasks = self.task_repository.get_tasks_current_round(filters) converted_tasks = [] From 21c8a88d3a31804562fcb8538bd1bfc5169d98f1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sara=20Hincapi=C3=A9=20M?= <43832784+shincap8@users.noreply.github.com> Date: Fri, 21 Nov 2025 09:39:16 -0500 Subject: [PATCH 23/30] Update backend/app/domain/services/base/task.py Co-authored-by: Rafael Mosquera --- backend/app/domain/services/base/task.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/app/domain/services/base/task.py b/backend/app/domain/services/base/task.py index f3b77c0ec..b51fb1246 100644 --- a/backend/app/domain/services/base/task.py +++ b/backend/app/domain/services/base/task.py @@ -492,7 +492,7 @@ def get_task_metrics_meta(self, task): def get_task_trends(self, task_id: int): """ - Get top perform models and its round wise performance metrics at task level + Get top performance models and its round-wise performance metrics at task level It will fetch only top 10 models and its round wise performance metrics :param tid: Task id :return: Json Object From 03ade2fc75f4e63af3afc854be50cf866fad4384 Mon Sep 17 00:00:00 2001 From: Sara H Date: Sat, 22 Nov 2025 14:11:23 -0500 Subject: [PATCH 24/30] move endpoints from Bottle to FastAPI --- backend/app/api/endpoints/base/round.py | 16 ++++++++- backend/app/api/endpoints/base/score.py | 2 +- backend/app/api/endpoints/base/task.py | 18 ++++++++++ backend/app/domain/schemas/base/round.py | 23 ++++++++++++ backend/app/domain/services/base/round.py | 12 +++++++ backend/app/domain/services/base/task.py | 29 +++++++++++++++ .../infrastructure/repositories/example.py | 35 ++++++++++++++++++- .../app/infrastructure/repositories/task.py | 8 +++++ frontends/web/src/common/ApiService.js | 20 +++++++---- 9 files changed, 153 insertions(+), 10 deletions(-) create mode 100644 backend/app/domain/schemas/base/round.py diff --git a/backend/app/api/endpoints/base/round.py b/backend/app/api/endpoints/base/round.py index 18c5dbb7a..dce0ff969 100644 --- a/backend/app/api/endpoints/base/round.py +++ b/backend/app/api/endpoints/base/round.py @@ -2,8 +2,13 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from fastapi import APIRouter +from typing import List +from fastapi import APIRouter, Depends, Request + +from app.api.middleware.authentication import validate_access_token +from app.domain.auth.authentication import LoginService +from app.domain.schemas.base.round import RoundResponse from app.domain.services.base.round import RoundService @@ -13,3 +18,12 @@ @router.get("/get_examples_collected_per_round/{round_id}-{task_id}") def get_examples_collected_per_round(round_id: int, task_id: int): return RoundService().get_examples_collected_per_round(round_id, task_id) + + +@router.get("/get_all_task_rounds/{task_id}", response_model=List[RoundResponse]) +async def get_all_rounds( + task_id: int, request: Request, token_payload=Depends(validate_access_token) +): + if not LoginService().is_admin_or_owner(task_id, request): + raise PermissionError("Unauthorized access to get rounds.") + return RoundService().get_rounds_by_task_id(task_id) diff --git a/backend/app/api/endpoints/base/score.py b/backend/app/api/endpoints/base/score.py index 139000e1e..35ad5683d 100644 --- a/backend/app/api/endpoints/base/score.py +++ b/backend/app/api/endpoints/base/score.py @@ -18,7 +18,7 @@ @router.get("/get_maximun_principal_score_by_task/{task_id}", response_model={}) async def get_maximun_principal_score_by_task(task_id: int): - return ScoreService().get_maximun_principal_score_by_task(task_id) + return ScoreService().get_maximum_principal_score_by_task(task_id) @router.post("/read_users_score_csv/", response_model=CsvResponseModel) diff --git a/backend/app/api/endpoints/base/task.py b/backend/app/api/endpoints/base/task.py index c471749a0..925faa855 100644 --- a/backend/app/api/endpoints/base/task.py +++ b/backend/app/api/endpoints/base/task.py @@ -209,3 +209,21 @@ async def get_models_in_the_loop( if not LoginService().is_admin_or_owner(task_id, request): raise PermissionError("Unauthorized access to get task data.") return TaskService().get_models_in_the_loop(task_id) + + +@router.post("/create_round/{task_id}", response_model={}) +async def create_round( + task_id: int, request: Request, token_payload=Depends(validate_access_token) +): + if not LoginService().is_admin_or_owner(task_id, request): + raise PermissionError("Unauthorized access to create round.") + return TaskService().create_round(task_id) + + +@router.get("/get_model_identifiers/{task_id}", response_model={}) +async def get_model_identifiers( + task_id: int, request: Request, token_payload=Depends(validate_access_token) +): + if not LoginService().is_admin_or_owner(task_id, request): + raise PermissionError("Unauthorized access to get model identifiers.") + return TaskService().get_model_identifiers(task_id) diff --git a/backend/app/domain/schemas/base/round.py b/backend/app/domain/schemas/base/round.py new file mode 100644 index 000000000..1d3de740c --- /dev/null +++ b/backend/app/domain/schemas/base/round.py @@ -0,0 +1,23 @@ +# Copyright (c) MLCommons and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from datetime import datetime, time +from typing import Optional + +from pydantic import BaseModel + + +class RoundResponse(BaseModel): + id: int + tid: int + rid: int + url: Optional[str] + desc: Optional[str] + longdesc: Optional[str] + total_fooled: int + total_verified_fooled: int + total_collected: int + total_time_spent: Optional[time] + start_datetime: Optional[datetime] + end_datetime: Optional[datetime] diff --git a/backend/app/domain/services/base/round.py b/backend/app/domain/services/base/round.py index 414056667..d70180024 100644 --- a/backend/app/domain/services/base/round.py +++ b/backend/app/domain/services/base/round.py @@ -24,3 +24,15 @@ def get_examples_collected_per_round(self, round_id: int, task_id: int): return self.round_repository.get_examples_collected_per_round( round_id, task_id ).total_collected + + def get_rounds_by_task_id(self, task_id: int): + rounds = self.round_repository.get_rounds_by_task_id(task_id) + rounds_dicts = [] + for round_instance in rounds: + rounds_dicts.append( + self.round_repository.instance_converter.instance_to_dict( + round_instance + ) + ) + rounds_dicts.sort(key=lambda r: r["rid"]) + return rounds_dicts diff --git a/backend/app/domain/services/base/task.py b/backend/app/domain/services/base/task.py index f3b77c0ec..eb83e88ea 100644 --- a/backend/app/domain/services/base/task.py +++ b/backend/app/domain/services/base/task.py @@ -5,6 +5,7 @@ import json import os import random +import secrets from ast import literal_eval from typing import Union @@ -682,3 +683,31 @@ def get_models_in_the_loop(self, task_id: int): ) rid_to_model_identifiers[round.rid] = model_identifiers return rid_to_model_identifiers + + def create_round(self, task_id: int): + task = self.task_repository.get_task_info_by_task_id(task_id).__dict__ + self.task_repository.increment_task_round(task_id) + self.round_repository.add( + { + "tid": task_id, + "rid": task["cur_round"] + 1, + "secret": secrets.token_hex(), + } + ) + return {"success": "ok"} + + def get_model_identifiers(self, task_id): + models = self.model_repository.get_models_by_task_id(task_id) + model_identifiers = [] + for model in models: + model_identifiers.append( + { + "model_name": model.name, + "model_id": model.id, + "deployment_status": model.deployment_status, + "is_published": bool(model.is_published), + "uid": model.uid, + "username": model.user.username, + } + ) + return model_identifiers diff --git a/backend/app/infrastructure/repositories/example.py b/backend/app/infrastructure/repositories/example.py index 7d2117f9e..4d9bdd5d6 100644 --- a/backend/app/infrastructure/repositories/example.py +++ b/backend/app/infrastructure/repositories/example.py @@ -7,7 +7,7 @@ # LICENSE file in the root directory of this source tree. from pydantic import Json -from sqlalchemy import func +from sqlalchemy import exists, func, not_ from app.infrastructure.models.models import Context, Example, Round, Validation from app.infrastructure.repositories.abstract import AbstractRepository @@ -223,3 +223,36 @@ def get_used_models_by_user_id_and_task_id(self, user_id: int, task_id: int): .distinct() .all() ) + + def get_examples_by_task_id_and_round_id_with_validations_ids( + self, task_id: int, round_id: int + ): + try: + validation_ids_concat = func.group_concat(Validation.id) + validations_query = ( + self.session.query(Example, validation_ids_concat) + .join(Context, Example.cid == Context.id) + .join(Round, Context.r_realid == Round.id) + .filter(Round.tid == task_id) + .filter(Round.rid == round_id) + .join(Validation, Validation.eid == Example.id) + .group_by(Validation.eid) + ) + empty_concat = func.group_concat("") + + no_validations_query = ( + self.session.query(Example, empty_concat) + .join(Context, Example.cid == Context.id) + .join(Round, Context.r_realid == Round.id) + .filter(Round.tid == task_id) + .filter(Round.rid == round_id) + .filter(not_(exists().where(Validation.eid == Example.id))) + .group_by(Example.id) + ) + return validations_query.union(no_validations_query).all() + except Exception as e: + # Handle any database-related exceptions + print( + f"Error in get_examples_by_task_id_and_round_id_with_validations_ids: {e}" + ) + return [] diff --git a/backend/app/infrastructure/repositories/task.py b/backend/app/infrastructure/repositories/task.py index 7b4a2e10b..5f26cf330 100644 --- a/backend/app/infrastructure/repositories/task.py +++ b/backend/app/infrastructure/repositories/task.py @@ -216,3 +216,11 @@ def update_task(self, task_id: int, update_data: dict): ) self.session.flush() self.session.commit() + + def increment_task_round(self, task_id: int): + with self.session as session: + session.query(self.model).filter(self.model.id == task_id).update( + {self.model.cur_round: self.model.cur_round + 1} + ) + session.flush() + session.commit() diff --git a/frontends/web/src/common/ApiService.js b/frontends/web/src/common/ApiService.js index e823ae3d0..bd8fa8f8d 100644 --- a/frontends/web/src/common/ApiService.js +++ b/frontends/web/src/common/ApiService.js @@ -647,9 +647,12 @@ export default class ApiService { } getRounds(tid) { - return this.fetch(`${this.domain}/tasks/get_all_rounds/${tid}`, { - method: "GET", - }); + return this.fetch( + `${this.alternateDomain}/round/get_all_task_rounds/${tid}`, + { + method: "GET", + } + ); } activateTask(tid, config_yaml) { @@ -670,7 +673,7 @@ export default class ApiService { } createRound(tid) { - return this.fetch(`${this.domain}/tasks/create_round/${tid}`, { + return this.fetch(`${this.alternateDomain}/task/create_round/${tid}`, { method: "POST", }); } @@ -685,9 +688,12 @@ export default class ApiService { } getModelIdentifiers(tid) { - return this.fetch(`${this.domain}/tasks/get_model_identifiers/${tid}`, { - method: "GET", - }); + return this.fetch( + `${this.alternateDomain}/task/get_model_identifiers/${tid}`, + { + method: "GET", + } + ); } getAvailableDatasetAccessTypes() { From 311af4ab981d623e2b6eae8053cf3f7052d45368 Mon Sep 17 00:00:00 2001 From: Sara H Date: Sat, 22 Nov 2025 14:40:51 -0500 Subject: [PATCH 25/30] remove get_examples_by_task_id_and_round_id_with_validations_ids not using it at the moment --- .../infrastructure/repositories/example.py | 35 +------------------ 1 file changed, 1 insertion(+), 34 deletions(-) diff --git a/backend/app/infrastructure/repositories/example.py b/backend/app/infrastructure/repositories/example.py index 4d9bdd5d6..7d2117f9e 100644 --- a/backend/app/infrastructure/repositories/example.py +++ b/backend/app/infrastructure/repositories/example.py @@ -7,7 +7,7 @@ # LICENSE file in the root directory of this source tree. from pydantic import Json -from sqlalchemy import exists, func, not_ +from sqlalchemy import func from app.infrastructure.models.models import Context, Example, Round, Validation from app.infrastructure.repositories.abstract import AbstractRepository @@ -223,36 +223,3 @@ def get_used_models_by_user_id_and_task_id(self, user_id: int, task_id: int): .distinct() .all() ) - - def get_examples_by_task_id_and_round_id_with_validations_ids( - self, task_id: int, round_id: int - ): - try: - validation_ids_concat = func.group_concat(Validation.id) - validations_query = ( - self.session.query(Example, validation_ids_concat) - .join(Context, Example.cid == Context.id) - .join(Round, Context.r_realid == Round.id) - .filter(Round.tid == task_id) - .filter(Round.rid == round_id) - .join(Validation, Validation.eid == Example.id) - .group_by(Validation.eid) - ) - empty_concat = func.group_concat("") - - no_validations_query = ( - self.session.query(Example, empty_concat) - .join(Context, Example.cid == Context.id) - .join(Round, Context.r_realid == Round.id) - .filter(Round.tid == task_id) - .filter(Round.rid == round_id) - .filter(not_(exists().where(Validation.eid == Example.id))) - .group_by(Example.id) - ) - return validations_query.union(no_validations_query).all() - except Exception as e: - # Handle any database-related exceptions - print( - f"Error in get_examples_by_task_id_and_round_id_with_validations_ids: {e}" - ) - return [] From 0a8b402fb8fa4ab462ca17b6d1dd6785fa0b8900 Mon Sep 17 00:00:00 2001 From: Sara H Date: Mon, 24 Nov 2025 16:52:32 -0500 Subject: [PATCH 26/30] Move Forgot password pipeline to Backend --- backend/app/api/endpoints/auth.py | 24 +++++- backend/app/domain/auth/authentication.py | 74 ++++++++++++++++++- backend/app/domain/helpers/helper.py | 22 ++++++ backend/app/domain/schemas/auth/auth.py | 8 ++ backend/app/domain/services/base/user.py | 9 +++ .../app/infrastructure/repositories/user.py | 29 ++++++++ frontends/web/src/common/ApiService.js | 15 ++-- 7 files changed, 171 insertions(+), 10 deletions(-) create mode 100644 backend/app/domain/helpers/helper.py diff --git a/backend/app/api/endpoints/auth.py b/backend/app/api/endpoints/auth.py index 38dc59c43..0af9aefe2 100644 --- a/backend/app/api/endpoints/auth.py +++ b/backend/app/api/endpoints/auth.py @@ -2,11 +2,17 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from fastapi import APIRouter, Depends, Header, Request, Response +from fastapi import APIRouter, Depends, Header, Query, Request, Response from app.api.middleware.authentication import validate_access_token from app.domain.auth.authentication import LoginService -from app.domain.schemas.auth.auth import CreateUserRequest, LoginRequest, LoginResponse +from app.domain.schemas.auth.auth import ( + CreateUserRequest, + LoginRequest, + LoginResponse, + NewPasswordRequest, + RecoverPasswordRequest, +) router = APIRouter() @@ -43,3 +49,17 @@ async def logout( request: Request, response: Response, token_payload=Depends(validate_access_token) ): return LoginService().logout(request, response) + + +@router.post("/recover/initiate") +async def recover_password(model: RecoverPasswordRequest, request: Request): + return LoginService().initiate_password_recovery(model.email, request) + + +@router.post("/recover/resolve") +async def resolve_recovery( + model: NewPasswordRequest, forgot_token: str = Query(..., alias="forgot_token") +): + return LoginService().resolve_password_recovery( + forgot_token, model.email, model.new_password + ) diff --git a/backend/app/domain/auth/authentication.py b/backend/app/domain/auth/authentication.py index 19782dd3b..b7beeed4c 100644 --- a/backend/app/domain/auth/authentication.py +++ b/backend/app/domain/auth/authentication.py @@ -11,7 +11,10 @@ from jose import jwt from werkzeug.security import check_password_hash, generate_password_hash +import app.domain.helpers.helper as util +from app.domain.helpers.email import EmailHelper from app.domain.helpers.exceptions import ( + bad_token, credentials_exception, password_is_incorrect, refresh_token_expired, @@ -43,6 +46,8 @@ def __init__(self) -> None: self.badges_repository = BadgeRepository() self.refresh_token_repository = RefreshTokenRepository() self.users_repository = UserRepository() + self.email_helper = EmailHelper() + self.email_sender = os.getenv("MAIL_LOGIN") def get_hashed_password(self, password: str) -> str: return generate_password_hash(password) @@ -91,10 +96,10 @@ def set_refresh_token(self, response, user_id: int) -> str: path="/", expires=cookie_expires, # For localhost testing set secure to False in Prod to True - secure=True, + secure=False, samesite="lax", # For localhost testing set domain to localhost - # domain="localhost" + domain="localhost", ) return refresh_token @@ -292,3 +297,68 @@ def cleanup_old_refresh_tokens(self, user_id: int): except Exception as e: print(f"Error cleaning up old refresh tokens: {e}") pass + + def initiate_password_recovery(self, email: str, request: Request): + """Initiate password recovery process by sending a recovery email with temporal token""" + parsed_origin_url = request.headers.get("origin") + if not ( + (hasattr(parsed_origin_url, "hostname")) + and (parsed_origin_url.hostname is not None) + and ( + parsed_origin_url.hostname + in [ + "dynabench.org", + "dev.dynabench.org", + "www.dynabench.org", + "beta.dynabench.org", + "api.dynabench.org", + ] + ) + ): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Invalid origin header", + ) + + user = self.users_service.get_by_email(email) + if not user: + user_does_not_exist() + try: + forgot_password_token = secrets.token_hex(64) + expiry_datetime = datetime.now() + timedelta(hours=4) + self.users_service.store_password_recovery_token( + user["id"], forgot_password_token, expiry_datetime + ) + ui_server_host = util.parse_url(request.url._url, parsed_origin_url) + + self.email_helper.send( + contact=email, + cc_contact=self.email_sender, + template_name="forgot_password.txt", + msg_dict={ + "username": user["username"], + "token": forgot_password_token, + "ui_server_host": ui_server_host, + }, + subject="Password Reset Request", + ) + + return {"status": "success"} + except Exception as e: + print(f"Error initiating password recovery: {e}") + + def resolve_password_recovery(self, forgot_token, email, new_password): + """Resolve password recovery by validating the token and updating the user's password""" + print("password", new_password) + user = self.users_service.get_by_forgot_token(forgot_token) + if not user: + print("User does not exist") + user_does_not_exist() + if datetime.now() > user["forgot_password_token_expiry_date"]: + bad_token() + if user["email"] != email: + user_does_not_exist() + + self.users_repository.update_password( + user["id"], self.get_hashed_password(new_password) + ) diff --git a/backend/app/domain/helpers/helper.py b/backend/app/domain/helpers/helper.py new file mode 100644 index 000000000..9d97ade4d --- /dev/null +++ b/backend/app/domain/helpers/helper.py @@ -0,0 +1,22 @@ +# Copyright (c) MLCommons and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from urllib.parse import urlparse + + +def parse_url(url, host_name=None): + """ + parse and extract the host name and server scheme from request url + :param url: + :return: url hostname {https://dynabench.org} + """ + + try: + if not host_name: + parsed_uri = urlparse(url) + formed_url = "{uri.scheme}://{uri.netloc}".format(uri=parsed_uri) + return formed_url + return host_name + except Exception: + return "https://dynabench.org" diff --git a/backend/app/domain/schemas/auth/auth.py b/backend/app/domain/schemas/auth/auth.py index 2c85b4fc1..c51ff9d2d 100644 --- a/backend/app/domain/schemas/auth/auth.py +++ b/backend/app/domain/schemas/auth/auth.py @@ -29,3 +29,11 @@ class LoginResponse(BaseModel): class TokenPayload(BaseModel): access_token: str token_type: str + + +class RecoverPasswordRequest(BaseModel): + email: EmailStr + + +class NewPasswordRequest(RecoverPasswordRequest): + new_password: str diff --git a/backend/app/domain/services/base/user.py b/backend/app/domain/services/base/user.py index 4166d5d96..c444a2607 100644 --- a/backend/app/domain/services/base/user.py +++ b/backend/app/domain/services/base/user.py @@ -91,3 +91,12 @@ def get_user_basics_by_id(self, user_id: int): "id": user_id, "task_permissions": task_permissions, } + + def store_password_recovery_token(self, user_id: int, token: str, expires_at): + self.user_repository.store_password_recovery_token(user_id, token, expires_at) + + def get_by_forgot_token(self, forgot_token: str): + return self.user_repository.get_by_forgot_token(forgot_token) + + def update_password(self, user_id: int, new_password: str): + self.user_repository.update_password(user_id, new_password) diff --git a/backend/app/infrastructure/repositories/user.py b/backend/app/infrastructure/repositories/user.py index 35b7c8760..6f5f867d6 100644 --- a/backend/app/infrastructure/repositories/user.py +++ b/backend/app/infrastructure/repositories/user.py @@ -126,3 +126,32 @@ def get_user_by_username(self, username: str): .filter(self.model.username == username) .first() ) + + def store_password_recovery_token(self, user_id: int, token: str, expires_at): + with self.session as session: + session.query(self.model).filter(self.model.id == user_id).update( + { + self.model.forgot_password_token: token, + self.model.forgot_password_token_expiry_date: expires_at, + } + ) + session.commit() + + def get_by_forgot_token(self, forgot_token: str): + instance = ( + self.session.query(self.model) + .filter(self.model.forgot_password_token == forgot_token) + .first() + ) + return self.instance_converter.instance_to_dict(instance) + + def update_password(self, user_id: int, new_hashed_password: str): + with self.session as session: + session.query(self.model).filter(self.model.id == user_id).update( + { + self.model.password: new_hashed_password, + self.model.forgot_password_token: None, + self.model.forgot_password_token_expiry_date: None, + } + ) + session.commit() diff --git a/frontends/web/src/common/ApiService.js b/frontends/web/src/common/ApiService.js index bd8fa8f8d..3389d526b 100644 --- a/frontends/web/src/common/ApiService.js +++ b/frontends/web/src/common/ApiService.js @@ -77,17 +77,20 @@ export default class ApiService { } forgotPassword(email) { - return this.fetch(`${this.domain}/recover/initiate`, { + return this.fetch(`${this.alternateDomain}/auth/recover/initiate`, { method: "POST", - body: JSON.stringify({ email }), + body: JSON.stringify({ email: email }), }); } resetPassword({ email, password, token }) { - return this.fetch(`${this.domain}/recover/resolve/${token}`, { - method: "POST", - body: JSON.stringify({ email, password }), - }); + return this.fetch( + `${this.alternateDomain}/auth/recover/resolve?forgot_token=${token}`, + { + method: "POST", + body: JSON.stringify({ email: email, new_password: password }), + } + ); } updateExample(id, data, uid = null) { From d6ba66bd63d803142894e0e960d659392c4b3570 Mon Sep 17 00:00:00 2001 From: Sara H Date: Mon, 24 Nov 2025 16:55:57 -0500 Subject: [PATCH 27/30] undo local config --- backend/app/domain/auth/authentication.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/backend/app/domain/auth/authentication.py b/backend/app/domain/auth/authentication.py index b7beeed4c..fc4111574 100644 --- a/backend/app/domain/auth/authentication.py +++ b/backend/app/domain/auth/authentication.py @@ -96,10 +96,10 @@ def set_refresh_token(self, response, user_id: int) -> str: path="/", expires=cookie_expires, # For localhost testing set secure to False in Prod to True - secure=False, + secure=True, samesite="lax", # For localhost testing set domain to localhost - domain="localhost", + # domain="localhost", ) return refresh_token From 8311fc5d6a02be840e6b694397e517468ebe92a6 Mon Sep 17 00:00:00 2001 From: Sara H Date: Tue, 25 Nov 2025 18:30:53 -0500 Subject: [PATCH 28/30] move more Bottle Endpoints to FastAPI Backend --- backend/app/api/endpoints/base/dataset.py | 36 +++++++++- backend/app/api/endpoints/base/task.py | 36 +++++++++- backend/app/domain/schemas/base/dataset.py | 11 +++ backend/app/domain/schemas/base/task.py | 4 ++ backend/app/domain/services/base/dataset.py | 53 ++++++++++++++ backend/app/domain/services/base/task.py | 70 +++++++++++++++++++ backend/app/infrastructure/models/models.py | 17 ++++- .../infrastructure/repositories/dataset.py | 18 +++++ .../infrastructure/repositories/example.py | 52 +++++++++++++- .../app/infrastructure/repositories/model.py | 20 ++++++ .../app/infrastructure/repositories/score.py | 8 +++ frontends/web/src/common/ApiService.js | 14 ++-- .../TaskOwnerPageComponents/Datasets.js | 3 + .../pages/Submissions/SubmitPrediction.tsx | 3 + 14 files changed, 332 insertions(+), 13 deletions(-) diff --git a/backend/app/api/endpoints/base/dataset.py b/backend/app/api/endpoints/base/dataset.py index 7c2cb0ff1..5ddb9524e 100644 --- a/backend/app/api/endpoints/base/dataset.py +++ b/backend/app/api/endpoints/base/dataset.py @@ -2,8 +2,10 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from fastapi import APIRouter, File, UploadFile +from fastapi import APIRouter, Depends, File, Request, UploadFile +from app.api.middleware.authentication import validate_access_token +from app.domain.schemas.base.dataset import UpdateDatasetInfo from app.domain.services.base.dataset import DatasetService @@ -24,3 +26,35 @@ async def create_dataset_in_db( access_type: str, ): return DatasetService().create_dataset_in_db(task_id, dataset_name, access_type) + + +@router.get("/task/{task_id}") +async def get_datasets_by_task_id(task_id: int): + return DatasetService().get_datasets_by_task_id(task_id) + + +@router.get("/get_access_types") +async def get_dataset_info_by_name(): + return DatasetService().get_dataset_info_by_name() + + +@router.get("/get_log_access_types") +async def get_log_access_types(): + return DatasetService().get_log_access_types() + + +@router.put("/update/{dataset_id}") +async def update_dataset( + dataset_id: int, + model: UpdateDatasetInfo, + request: Request, + token_payload=Depends(validate_access_token), +): + return DatasetService().update_dataset_access_type(dataset_id, request, model) + + +@router.delete("/delete/{dataset_id}") +async def delete_dataset( + dataset_id: int, request: Request, token_payload=Depends(validate_access_token) +): + return DatasetService().delete_dataset(dataset_id, request) diff --git a/backend/app/api/endpoints/base/task.py b/backend/app/api/endpoints/base/task.py index 925faa855..a70af0ffe 100644 --- a/backend/app/api/endpoints/base/task.py +++ b/backend/app/api/endpoints/base/task.py @@ -3,7 +3,7 @@ # LICENSE file in the root directory of this source tree. import os -from fastapi import APIRouter, Body, Depends, Request +from fastapi import APIRouter, Body, Depends, Query, Request from fastapi.responses import FileResponse from app.api.middleware.authentication import validate_access_token @@ -13,6 +13,7 @@ GetDynaboardInfoByTaskIdRequest, PreliminaryQuestionsRequest, SignInConsentRequest, + UpdateModelsInTheLoopRequest, UpdateTaskInstructions, UpdateYamlConfiguration, ) @@ -227,3 +228,36 @@ async def get_model_identifiers( if not LoginService().is_admin_or_owner(task_id, request): raise PermissionError("Unauthorized access to get model identifiers.") return TaskService().get_model_identifiers(task_id) + + +@router.put("/update_models_in_the_loop/{task_id}", response_model={}) +async def update_models_in_the_loop( + task_id: int, + request: Request, + model: UpdateModelsInTheLoopRequest, + token_payload=Depends(validate_access_token), +): + if not LoginService().is_admin_or_owner(task_id, request): + raise PermissionError("Unauthorized access to update models in the loop.") + return TaskService().update_models_in_the_loop(task_id, model.model_ids) + + +@router.get("/{task_id}/users", response_model={}) +async def get_user_leaderboard( + task_id: int, + limit: int = Query(5, alias="limit"), + offset: int = Query(0, alias="offset"), +): + return TaskService().get_user_leaderboard(task_id, limit, offset) + + +@router.get("/{task_id}/rounds/{round_id}/users", response_model={}) +async def get_leaderboard_by_task_and_round( + task_id: int, + round_id: int, + limit: int = Query(5, alias="limit"), + offset: int = Query(0, alias="offset"), +): + return TaskService().get_leaderboard_by_task_and_round( + task_id, round_id, limit, offset + ) diff --git a/backend/app/domain/schemas/base/dataset.py b/backend/app/domain/schemas/base/dataset.py index a66ef9850..a972f05db 100644 --- a/backend/app/domain/schemas/base/dataset.py +++ b/backend/app/domain/schemas/base/dataset.py @@ -1,3 +1,14 @@ # Copyright (c) MLCommons and its affiliates. # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from typing import Optional + +from pydantic import BaseModel + + +class UpdateDatasetInfo(BaseModel): + access_type: str + log_access_type: Optional[str] = None + longdesc: Optional[str] = None + rid: Optional[int] = 0 + source_url: Optional[str] = None diff --git a/backend/app/domain/schemas/base/task.py b/backend/app/domain/schemas/base/task.py index d91d9d01d..f13d1dfae 100644 --- a/backend/app/domain/schemas/base/task.py +++ b/backend/app/domain/schemas/base/task.py @@ -50,3 +50,7 @@ class CheckSignConsentRequest(BaseModel): class UpdateYamlConfiguration(BaseModel): task_id: int config_yaml: str + + +class UpdateModelsInTheLoopRequest(BaseModel): + model_ids: Optional[List[int]] = [] diff --git a/backend/app/domain/services/base/dataset.py b/backend/app/domain/services/base/dataset.py index 6d9362f4b..a02143265 100644 --- a/backend/app/domain/services/base/dataset.py +++ b/backend/app/domain/services/base/dataset.py @@ -11,13 +11,20 @@ import jsonlines from fastapi import File +from app.domain.auth.authentication import LoginService from app.domain.helpers.s3_helpers import S3Helpers +from app.domain.schemas.base.dataset import UpdateDatasetInfo +from app.infrastructure.models.models import AccessTypeEnum, LogAccessTypeEnum from app.infrastructure.repositories.dataset import DatasetRepository +from app.infrastructure.repositories.score import ScoreRepository +from app.infrastructure.repositories.task import TaskRepository class DatasetService: def __init__(self): self.dataset_repository = DatasetRepository() + self.score_repository = ScoreRepository() + self.task_repository = TaskRepository() self.s3_helpers = S3Helpers() def get_dataset_name_by_id(self, dataset_id: int): @@ -58,3 +65,49 @@ def upload_dataset( jsonl_contents.encode("utf-8"), f"datasets/{task_code}/{dataset_name}.jsonl" ) return "Dataset uploaded successfully" + + def get_datasets_by_task_id(self, task_id: int): + datasets_list = [] + datasets = self.dataset_repository.get_datasets_by_task_id(task_id) + if datasets: + for dataset in datasets: + datasets_list.append(dataset.__dict__) + return datasets_list + + def get_dataset_info_by_name(self): + return [enum.name for enum in AccessTypeEnum] + + def get_log_access_types(self): + return [enum.name for enum in LogAccessTypeEnum] + + def update_dataset_access_type( + self, dataset_id: int, request, model: UpdateDatasetInfo + ): + dataset = self.dataset_repository.get_dataset_info_by_id(dataset_id) + if not LoginService().is_admin_or_owner(dataset["tid"], request): + raise PermissionError("Unauthorized access to update models in the loop.") + data = model.__dict__ + for field in data.keys(): + if field not in ( + "longdesc", + "rid", + "source_url", + "access_type", + "log_access_type", + ): + raise ValueError(f"Invalid field: {field}") + self.dataset_repository.update_dataset_info(dataset_id, data) + return {"success": "ok"} + + def delete_dataset(self, dataset_id: int, request): + dataset = self.dataset_repository.get_dataset_info_by_id(dataset_id) + if not LoginService().is_admin_or_owner(dataset["tid"], request): + raise PermissionError("Unauthorized access to delete dataset.") + scores_to_delete = self.score_repository.get_scores_for_dataset(dataset_id) + + for score in scores_to_delete: + score_dict = score.__dict__ + self.score_repository.hide(score_dict["id"]) + + self.dataset_repository.hide_dataset(dataset_id) + return {"success": "ok"} diff --git a/backend/app/domain/services/base/task.py b/backend/app/domain/services/base/task.py index e60d30ba6..ea02501a6 100644 --- a/backend/app/domain/services/base/task.py +++ b/backend/app/domain/services/base/task.py @@ -711,3 +711,73 @@ def get_model_identifiers(self, task_id): } ) return model_identifiers + + def update_models_in_the_loop(self, task_id, model_ids=[]): + self.model_repository.clean_models_in_the_loop(task_id) + if len(model_ids) > 0: + for model_id in model_ids: + self.model_repository.update_model_in_the_loop(model_id) + return {"success": "ok"} + + def get_user_leaderboard(self, task_id: int, limit: int, offset: int): + """ + Return users and MER based on their examples score based on tasks + :param tid: Task id, limit: limit, offset: offset + :return: Json Object + """ + try: + task_r_realids = [] + rounds = self.round_repository.get_rounds_by_task_id(task_id) + for round_instance in rounds: + round_dict = round_instance.__dict__ + task_r_realids.append(round_dict["rid"]) + ( + query_result, + total_count, + ) = self.example_repository.getUserLeaderByRoundRealids( + task_r_realids, limit, offset + ) + return self.__construct_user_board_response_json(query_result, total_count) + + except Exception as e: + print(e) + return {"count": 0, "data": []} + + def get_leaderboard_by_task_and_round(self, task_id, round_id, limit, offset): + """ + Get top leaders based on their examples score for specific task and round + :param tid: Task id, limit: limit, offset: offset, :param rid: round id + :return: Json Object + """ + try: + round_instance = self.round_repository.get_round_info_by_round_and_task( + task_id, round_id + ).__dict__ + ( + query_result, + total_count, + ) = self.example_repository.getUserLeaderByRoundRealids( + [round_instance["id"]], limit, offset + ) + return self.__construct_user_board_response_json(query_result, total_count) + + except Exception as e: + print(e) + return {"count": 0, "data": []} + + def __construct_user_board_response_json(self, query_result, total_count=0): + list_objs = [] + for result in query_result: + obj = {} + obj["uid"] = result[0] + obj["username"] = result[1] + obj["avatar_url"] = result[2] if result[2] is not None else "" + obj["count"] = int(result[3]) + obj["MER"] = str(round(result[4] * 100, 2)) + obj["created"] = result[5] + obj["total"] = str(result[3]) + "/" + str(result[5]) + list_objs.append(obj) + if list_objs: + return {"count": total_count, "data": list_objs} + else: + return {"count": 0, "data": []} diff --git a/backend/app/infrastructure/models/models.py b/backend/app/infrastructure/models/models.py index 5b036d4ee..2ecc09f80 100644 --- a/backend/app/infrastructure/models/models.py +++ b/backend/app/infrastructure/models/models.py @@ -6,6 +6,8 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import enum + # coding: utf-8 from sqlalchemy import ( JSON, @@ -191,6 +193,17 @@ class Badge(Base): user = relationship("User") +class AccessTypeEnum(enum.Enum): + scoring = "scoring" + standard = "standard" + hidden = "hidden" + + +class LogAccessTypeEnum(enum.Enum): + owner = "owner" + user = "user" + + class Dataset(Base): __tablename__ = "datasets" @@ -201,8 +214,8 @@ class Dataset(Base): desc = Column(String(255)) longdesc = Column(Text) source_url = Column(Text) - access_type = Column(Enum("scoring", "standard", "hidden")) - log_access_type = Column(Enum("owner", "user")) + access_type = Column(Enum(AccessTypeEnum)) + log_access_type = Column(Enum(LogAccessTypeEnum)) tags = Column(Integer) has_downstream = Column(TINYINT(1)) weight = Column(Float) diff --git a/backend/app/infrastructure/repositories/dataset.py b/backend/app/infrastructure/repositories/dataset.py index e8449a332..314be5d71 100644 --- a/backend/app/infrastructure/repositories/dataset.py +++ b/backend/app/infrastructure/repositories/dataset.py @@ -6,6 +6,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from app.domain.schemas.base.dataset import UpdateDatasetInfo from app.infrastructure.models.models import Dataset from app.infrastructure.repositories.abstract import AbstractRepository @@ -131,3 +132,20 @@ def get_dataset_weight(self, dataset_id: int) -> dict: .filter(self.model.id == dataset_id) .one() ) + + def get_datasets_by_task_id(self, task_id: int): + return self.session.query(self.model).filter(self.model.tid == task_id).all() + + def update_dataset_info(self, dataset_id: int, update_data: UpdateDatasetInfo): + self.session.query(self.model).filter(self.model.id == dataset_id).update( + update_data + ) + with self.session as session: + session.commit() + + def hide_dataset(self, dataset_id: int): + self.session.query(self.model).filter(self.model.id == dataset_id).update( + {"tid": 0} + ) + with self.session as session: + session.commit() diff --git a/backend/app/infrastructure/repositories/example.py b/backend/app/infrastructure/repositories/example.py index 7d2117f9e..fe30ddc86 100644 --- a/backend/app/infrastructure/repositories/example.py +++ b/backend/app/infrastructure/repositories/example.py @@ -7,9 +7,16 @@ # LICENSE file in the root directory of this source tree. from pydantic import Json -from sqlalchemy import func +from sqlalchemy import desc, func -from app.infrastructure.models.models import Context, Example, Round, Validation +from app.infrastructure.models.models import ( + Context, + Example, + Round, + RoundUserExampleInfo, + User, + Validation, +) from app.infrastructure.repositories.abstract import AbstractRepository @@ -223,3 +230,44 @@ def get_used_models_by_user_id_and_task_id(self, user_id: int, task_id: int): .distinct() .all() ) + + def getUserLeaderByRoundRealids( + self, task_r_realids: list, limit: int, offset: int + ): + total_fooled_cnt = func.sum(RoundUserExampleInfo.total_fooled).label( + "total_fooled_cnt" + ) + total_verified_not_correct_fooled_cnt = func.sum( + RoundUserExampleInfo.total_verified_not_correct_fooled + ).label("total_verified_not_correct_fooled_cnt") + examples_submitted_cnt = func.sum( + RoundUserExampleInfo.examples_submitted + ).label("examples_submitted_cnt") + + verified_fooled = ( + total_fooled_cnt - total_verified_not_correct_fooled_cnt + ).label("verified_fooled") + fooling_rate = ( + (total_fooled_cnt - total_verified_not_correct_fooled_cnt) + / examples_submitted_cnt + ).label("fooling_rate") + + query_res = ( + self.session.query( + User.id, + User.username, + User.avatar_url, + verified_fooled, + fooling_rate, + examples_submitted_cnt, + ) + .join(RoundUserExampleInfo, RoundUserExampleInfo.uid == User.id) + .filter(RoundUserExampleInfo.r_realid.in_(task_r_realids)) + .group_by(RoundUserExampleInfo.uid) + .order_by(desc(examples_submitted_cnt)) + ) + results = query_res.limit(limit).offset(offset * limit).all() + + total_count = query_res.count() + + return results, total_count diff --git a/backend/app/infrastructure/repositories/model.py b/backend/app/infrastructure/repositories/model.py index 03398a11c..80271539a 100644 --- a/backend/app/infrastructure/repositories/model.py +++ b/backend/app/infrastructure/repositories/model.py @@ -310,3 +310,23 @@ def get_amount_of_models_uploaded_in_hr_diff( def get_models_by_task_id(self, task_id: int): return self.session.query(self.model).filter(self.model.tid == task_id).all() + + def clean_models_in_the_loop(self, task_id: int): + all_models_for_task = ( + self.session.query(self.model) + .filter(self.model.tid == task_id, self.model.is_in_the_loop == 1) + .all() + ) + for model in all_models_for_task: + model.is_in_the_loop = False + self.session.flush() + self.session.commit() + + def update_model_in_the_loop(self, model_id: int): + with self.session as session: + instance = ( + session.query(self.model).filter(self.model.id == model_id).first() + ) + instance.is_in_the_loop = True + session.commit() + session.flush() diff --git a/backend/app/infrastructure/repositories/score.py b/backend/app/infrastructure/repositories/score.py index dbc6e123b..32c1bdd39 100644 --- a/backend/app/infrastructure/repositories/score.py +++ b/backend/app/infrastructure/repositories/score.py @@ -118,3 +118,11 @@ def fix_f1_score(self, model_id: int): session.execute(sql, {"model_id": model_id}) session.flush() session.commit() + + def get_scores_for_dataset(self, dataset_id: int): + return self.session.query(Score).filter(Score.did == dataset_id).all() + + def hide(self, score_id: int): + with self.session as session: + session.query(Score).filter(Score.id == score_id).update({"r_realid": 0}) + session.commit() diff --git a/frontends/web/src/common/ApiService.js b/frontends/web/src/common/ApiService.js index 3389d526b..19082b34e 100644 --- a/frontends/web/src/common/ApiService.js +++ b/frontends/web/src/common/ApiService.js @@ -280,7 +280,7 @@ export default class ApiService { round === "overall" ? `/users?limit=${limit || 10}&offset=${offset || 0}` : `/rounds/${round}/users?limit=${limit || 10}&offset=${offset || 0}`; - return this.fetch(`${this.domain}/tasks/${taskId}${url}`, { + return this.fetch(`${this.alternateDomain}/task/${taskId}${url}`, { method: "GET", }); } @@ -667,7 +667,7 @@ export default class ApiService { updateModelsInTheLoop(tid, rid, data) { return this.fetch( - `${this.domain}/tasks/update_models_in_the_loop/${tid}/${rid}`, + `${this.alternateDomain}/task/update_models_in_the_loop/${tid}`, { method: "PUT", body: JSON.stringify(data), @@ -700,25 +700,25 @@ export default class ApiService { } getAvailableDatasetAccessTypes() { - return this.fetch(`${this.domain}/datasets/get_access_types`, { + return this.fetch(`${this.alternateDomain}/dataset/get_access_types`, { method: "GET", }); } getAvailableDatasetLogAccessTypes() { - return this.fetch(`${this.domain}/datasets/get_log_access_types`, { + return this.fetch(`${this.alternateDomain}/dataset/get_log_access_types`, { method: "GET", }); } getDatasets(tid) { - return this.fetch(`${this.domain}/tasks/datasets/${tid}`, { + return this.fetch(`${this.alternateDomain}/dataset/task/${tid}`, { method: "GET", }); } updateDataset(did, data) { - return this.fetch(`${this.domain}/datasets/update/${did}`, { + return this.fetch(`${this.alternateDomain}/dataset/update/${did}`, { method: "PUT", body: JSON.stringify(data), }); @@ -740,7 +740,7 @@ export default class ApiService { } deleteDataset(did) { - return this.fetch(`${this.domain}/datasets/delete/${did}`, { + return this.fetch(`${this.alternateDomain}/dataset/delete/${did}`, { method: "DELETE", }); } diff --git a/frontends/web/src/components/TaskOwnerPageComponents/Datasets.js b/frontends/web/src/components/TaskOwnerPageComponents/Datasets.js index d0848541a..458f7795f 100644 --- a/frontends/web/src/components/TaskOwnerPageComponents/Datasets.js +++ b/frontends/web/src/components/TaskOwnerPageComponents/Datasets.js @@ -67,6 +67,8 @@ const Datasets = (props) => { const config = yaml.load(props.task.config_yaml); const delta_metric_configs = config.delta_metrics ? config.delta_metrics : []; const delta_files = {}; + const token = localStorage.getItem("id_token") || ""; + axios.defaults.headers.common["Authorization"] = `Bearer ${token}`; for (const config of delta_metric_configs) { delta_files[config.type] = null; @@ -86,6 +88,7 @@ const Datasets = (props) => { dataset_name: values.name, task_code: props.task.task_code, }, + withCredentials: true, }) .then((response) => { if (response.status === 200) { diff --git a/frontends/web/src/new_front/pages/Submissions/SubmitPrediction.tsx b/frontends/web/src/new_front/pages/Submissions/SubmitPrediction.tsx index 8b2787abd..7a32186de 100644 --- a/frontends/web/src/new_front/pages/Submissions/SubmitPrediction.tsx +++ b/frontends/web/src/new_front/pages/Submissions/SubmitPrediction.tsx @@ -23,6 +23,8 @@ const SubmitPrediction = () => { reValidateMode: "onSubmit", defaultValues: initState, }); + const token = localStorage.getItem("id_token") || ""; + axios.defaults.headers.common["Authorization"] = `Bearer ${token}`; const isLogin = async () => { if (!user.id) { @@ -72,6 +74,7 @@ const SubmitPrediction = () => { task_code: taskCode, model_name: modelData.modelName.replace(/\s/g, "_"), }, + withCredentials: true, }) .then(() => { Swal.fire({ From 0bb5dd13b38701b8dc2552bb529d5a5db309a8ff Mon Sep 17 00:00:00 2001 From: Sara H Date: Tue, 25 Nov 2025 19:10:09 -0500 Subject: [PATCH 29/30] delete hide for score --- backend/app/domain/services/base/dataset.py | 5 ----- backend/app/infrastructure/repositories/score.py | 5 ----- 2 files changed, 10 deletions(-) diff --git a/backend/app/domain/services/base/dataset.py b/backend/app/domain/services/base/dataset.py index a02143265..9732bd89d 100644 --- a/backend/app/domain/services/base/dataset.py +++ b/backend/app/domain/services/base/dataset.py @@ -103,11 +103,6 @@ def delete_dataset(self, dataset_id: int, request): dataset = self.dataset_repository.get_dataset_info_by_id(dataset_id) if not LoginService().is_admin_or_owner(dataset["tid"], request): raise PermissionError("Unauthorized access to delete dataset.") - scores_to_delete = self.score_repository.get_scores_for_dataset(dataset_id) - - for score in scores_to_delete: - score_dict = score.__dict__ - self.score_repository.hide(score_dict["id"]) self.dataset_repository.hide_dataset(dataset_id) return {"success": "ok"} diff --git a/backend/app/infrastructure/repositories/score.py b/backend/app/infrastructure/repositories/score.py index 32c1bdd39..737168188 100644 --- a/backend/app/infrastructure/repositories/score.py +++ b/backend/app/infrastructure/repositories/score.py @@ -121,8 +121,3 @@ def fix_f1_score(self, model_id: int): def get_scores_for_dataset(self, dataset_id: int): return self.session.query(Score).filter(Score.did == dataset_id).all() - - def hide(self, score_id: int): - with self.session as session: - session.query(Score).filter(Score.id == score_id).update({"r_realid": 0}) - session.commit() From 4c1af61649df6417568592f107bc006b7aabd1d9 Mon Sep 17 00:00:00 2001 From: Sara H Date: Fri, 28 Nov 2025 19:06:50 -0500 Subject: [PATCH 30/30] fix methods to implement new enum type --- backend/app/domain/services/base/task.py | 3 ++- backend/app/infrastructure/repositories/dataset.py | 14 ++++++++------ 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/backend/app/domain/services/base/task.py b/backend/app/domain/services/base/task.py index ea02501a6..f2a05bacc 100644 --- a/backend/app/domain/services/base/task.py +++ b/backend/app/domain/services/base/task.py @@ -20,6 +20,7 @@ from app.domain.services.builder_and_evaluation.eval_utils.metrics_dicts import ( meta_metrics_dict, ) +from app.infrastructure.models.models import AccessTypeEnum from app.infrastructure.repositories.dataset import DatasetRepository from app.infrastructure.repositories.example import ExampleRepository from app.infrastructure.repositories.historical_data import HistoricalDataRepository @@ -412,7 +413,7 @@ def get_task_with_round_and_metric_data(self, task_id_or_code: Union[int, str]): scoring_dataset_list = [] for dataset in datasets: dataset_list.append({"id": dataset.id, "name": dataset.name}) - if dataset.access_type == "scoring": + if dataset.access_type == AccessTypeEnum.scoring: scoring_dataset_list.append( { "id": dataset.id, diff --git a/backend/app/infrastructure/repositories/dataset.py b/backend/app/infrastructure/repositories/dataset.py index 314be5d71..a335ade25 100644 --- a/backend/app/infrastructure/repositories/dataset.py +++ b/backend/app/infrastructure/repositories/dataset.py @@ -7,7 +7,7 @@ # LICENSE file in the root directory of this source tree. from app.domain.schemas.base.dataset import UpdateDatasetInfo -from app.infrastructure.models.models import Dataset +from app.infrastructure.models.models import AccessTypeEnum, Dataset from app.infrastructure.repositories.abstract import AbstractRepository @@ -17,7 +17,8 @@ def __init__(self) -> None: def get_scoring_datasets(self, task_id: int, dataset_name: str = None) -> dict: scoring_datasets = self.session.query(self.model).filter( - (self.model.access_type == "scoring") & (self.model.tid == task_id) + (self.model.access_type == AccessTypeEnum.scoring) + & (self.model.tid == task_id) ) if dataset_name: scoring_datasets = scoring_datasets.filter(self.model.name == dataset_name) @@ -33,7 +34,8 @@ def get_scoring_datasets(self, task_id: int, dataset_name: str = None) -> dict: def get_not_scoring_datasets(self, task_id: int) -> dict: no_scoring_datasets = self.session.query(self.model).filter( - (self.model.access_type != "scoring") & (self.model.tid == task_id) + (self.model.access_type != AccessTypeEnum.scoring) + & (self.model.tid == task_id) ) jsonl_no_scoring_datasets = [] @@ -66,7 +68,7 @@ def get_order_scoring_datasets_by_task_id(self, task_id: int) -> dict: self.session.query(self.model) .order_by(self.model.id) .filter(self.model.tid == task_id) - .filter(self.model.access_type == "scoring") + .filter(self.model.access_type == AccessTypeEnum.scoring) .all() ) @@ -95,7 +97,7 @@ def get_scoring_datasets_by_task_id(self, task_id: int) -> dict: return ( self.session.query(self.model.id) .filter(self.model.tid == task_id) - .filter(self.model.access_type == "scoring") + .filter(self.model.access_type == AccessTypeEnum.scoring) .all() ) @@ -114,7 +116,7 @@ def get_downstream_datasets(self, task_id: int) -> dict: downstream_datasets = ( self.session.query(self.model) .filter(self.model.tid == task_id) - .filter(self.model.access_type == "scoring") + .filter(self.model.access_type == AccessTypeEnum.scoring) ) jsonl_downstream_datasets = [] for downstream_dataset in downstream_datasets: