From 54a800c2aeb434176a8b8ab069c41aa36649bd5c Mon Sep 17 00:00:00 2001 From: Andrew Peng Date: Mon, 13 May 2024 21:33:43 -0500 Subject: [PATCH] Migrate training into celery and upload results to s3 (#1157) * upload train results to s3, add endpoint to get train results, frontend request from endpoint :art: Auto-generated directory tree for repository in Architecture.md :art: Format Python code with psf/black slight reformatting add refetch reduce duplicated code * resolve comments :art: Auto-generated directory tree for repository in Architecture.md * :art: Format Python code with psf/black * update dockerfiles so that it automatically spins up celery and redis in dev env * call createTrainspace, pass id to train() * modify get_trainspace to return detailedtrainresultsdata * resolve sonarcloud * :art: Format Python code with psf/black * resolve sonarcloud * :art: Auto-generated directory tree for repository in Architecture.md * :art: Format Python code with psf/black * pass tests * fix nits --------- Co-authored-by: andrewpeng02 --- .github/Architecture.md | 20 +- dlp-terraform/ecs/s3.tf | 15 + dlp-terraform/ecs/sqs.tf | 27 + frontend/next.config.js | 2 +- .../Image/components/ImageTrainspace.tsx | 18 +- .../Train/features/Image/redux/imageApi.ts | 5 +- .../Tabular/components/TabularTrainspace.tsx | 20 +- .../features/Tabular/redux/tabularApi.ts | 5 +- .../src/features/Train/redux/trainspaceApi.ts | 36 + .../src/features/Train/types/trainTypes.ts | 50 +- frontend/src/pages/train/[train_space_id].tsx | 306 +---- .../src/pages/train/metrics_to_charts.tsx | 137 +++ serverless/package.json | 5 +- .../src/trainspace/create_trainspace.ts | 133 ++- .../src/trainspace/get_trainspace.ts | 109 +- .../tests/create_trainspace.test.ts | 8 +- serverless/pnpm-lock.yaml | 1013 ++++++++++++++--- serverless/stacks/AppStack.ts | 2 +- training/Dockerfile | 43 +- training/Dockerfile.prod | 36 - training/README.md | 11 +- training/docker-compose.prod.yml | 10 +- training/docker-compose.yml | 28 +- training/poetry.lock | 285 ++++- training/pyproject.toml | 1 + training/tests/test_imports.py | 10 +- training/tests/test_loss_function.py | 2 +- training/tests/test_model.py | 2 +- ...est_sk_learn_default_dataset_train_test.py | 2 +- training/training/celery_app.py | 18 + training/training/celeryconfig.py | 23 + training/training/constants.py | 1 + training/training/core/authenticator.py | 6 +- training/training/core/celery/Dockerfile | 51 + training/training/core/celery/__init__.py | 0 .../training/core/{ => celery}/criterion.py | 0 .../training/core/{ => celery}/dataset.py | 0 .../training/core/{ => celery}/dl_model.py | 0 .../training/core/{ => celery}/optimizer.py | 0 training/training/core/celery/train_types.py | 74 ++ .../training/core/{ => celery}/trainer.py | 2 +- training/training/core/celery/worker.py | 194 ++++ .../routes/datasets/default/columns.py | 3 +- training/training/routes/image/image.py | 41 +- training/training/routes/image/schemas.py | 1 + training/training/routes/tabular/schemas.py | 1 + training/training/routes/tabular/tabular.py | 65 +- 47 files changed, 2146 insertions(+), 675 deletions(-) create mode 100644 dlp-terraform/ecs/s3.tf create mode 100644 dlp-terraform/ecs/sqs.tf create mode 100644 frontend/src/pages/train/metrics_to_charts.tsx delete mode 100644 training/Dockerfile.prod create mode 100644 training/training/celery_app.py create mode 100644 training/training/celeryconfig.py create mode 100644 training/training/constants.py create mode 100644 training/training/core/celery/Dockerfile create mode 100644 training/training/core/celery/__init__.py rename training/training/core/{ => celery}/criterion.py (100%) rename training/training/core/{ => celery}/dataset.py (100%) rename training/training/core/{ => celery}/dl_model.py (100%) rename training/training/core/{ => celery}/optimizer.py (100%) create mode 100644 training/training/core/celery/train_types.py rename training/training/core/{ => celery}/trainer.py (99%) create mode 100644 training/training/core/celery/worker.py diff --git a/.github/Architecture.md b/.github/Architecture.md index c790c343e..c140163df 100644 --- a/.github/Architecture.md +++ b/.github/Architecture.md @@ -30,18 +30,26 @@ | | | |- 📜 __init__.py | | | |- 📜 health_check_middleware.py | | |- 📂 core: -| | | |- 📜 trainer.py -| | | |- 📜 criterion.py -| | | |- 📜 dl_model.py : torch model based on user specifications from drag and drop -| | | |- 📜 dataset.py : read in the dataset through URL or file upload +| | | |- 📂 celery: +| | | | |- 📜 trainer.py +| | | | |- 📜 criterion.py +| | | | |- 📜 dl_model.py : torch model based on user specifications from drag and drop +| | | | |- 📜 train_types.py +| | | | |- 📜 dataset.py : read in the dataset through URL or file upload +| | | | |- 📜 __init__.py +| | | | |- 📜 Dockerfile +| | | | |- 📜 worker.py +| | | | |- 📜 optimizer.py : what optimizer to use (ie: SGD or Adam for now) | | | |- 📜 __init__.py | | | |- 📜 authenticator.py -| | | |- 📜 optimizer.py : what optimizer to use (ie: SGD or Adam for now) | | |- 📜 asgi.py +| | |- 📜 constants.py : list of helpful constants +| | |- 📜 celery_app.py | | |- 📜 settings.py | | |- 📜 __init__.py | | |- 📜 wsgi.py | | |- 📜 urls.py +| | |- 📜 celeryconfig.py | |- 📜 README.md | |- 📜 docker-compose.yml | |- 📜 cli.py @@ -52,7 +60,6 @@ | |- 📜 manage.py | |- 📜 environment.yml | |- 📜 docker-compose.prod.yml -| |- 📜 Dockerfile.prod ``` ## Frontend Architecture @@ -210,6 +217,7 @@ | | |- 📂 pages: | | | |- 📂 train: | | | | |- 📜 [train_space_id].tsx +| | | | |- 📜 metrics_to_charts.tsx | | | | |- 📜 index.tsx | | | |- 📜 _app.tsx | | | |- 📜 forgot.tsx diff --git a/dlp-terraform/ecs/s3.tf b/dlp-terraform/ecs/s3.tf new file mode 100644 index 000000000..2631fc1d5 --- /dev/null +++ b/dlp-terraform/ecs/s3.tf @@ -0,0 +1,15 @@ +resource "aws_s3_bucket" "s3bucket_executions" { + bucket = "dlp-executions" + + tags = { + Name = "Execution data" + } +} +resource "aws_s3_bucket_public_access_block" "access_block_uploads" { + bucket = aws_s3_bucket.s3bucket_executions.id + + block_public_acls = true + block_public_policy = true + ignore_public_acls = true + restrict_public_buckets = true +} diff --git a/dlp-terraform/ecs/sqs.tf b/dlp-terraform/ecs/sqs.tf new file mode 100644 index 000000000..d5ca20179 --- /dev/null +++ b/dlp-terraform/ecs/sqs.tf @@ -0,0 +1,27 @@ +resource "aws_sqs_queue" "training_queue" { + name = "training-queue.fifo" + fifo_queue = true + message_retention_seconds = 60*24 + + redrive_policy = jsonencode({ + deadLetterTargetArn = aws_sqs_queue.training_queue_deadletter.arn + maxReceiveCount = 4 + }) +} + +resource "aws_sqs_queue" "training_queue_deadletter" { + name = "training-deadletter-queue" +} + +resource "aws_sqs_queue_redrive_allow_policy" "training_queue_redrive_allow_policy" { + queue_url = aws_sqs_queue.training_queue_deadletter.id + + redrive_allow_policy = jsonencode({ + redrivePermission = "byQueue", + sourceQueueArns = [aws_sqs_queue.training_queue.arn] + }) +} + +output "sqs_queue_url" { + value = aws_sqs_queue.training_queue.url +} \ No newline at end of file diff --git a/frontend/next.config.js b/frontend/next.config.js index 089c75cb2..180c1257b 100644 --- a/frontend/next.config.js +++ b/frontend/next.config.js @@ -19,7 +19,7 @@ const nextConfig = { { source: "/api/lambda/:path*", destination: - "https://em9iri9g4j.execute-api.us-west-2.amazonaws.com/:path*", + "https://qt6nzp3sjd.execute-api.us-east-1.amazonaws.com/:path*", }, { source: "/api/training/:path*", diff --git a/frontend/src/features/Train/features/Image/components/ImageTrainspace.tsx b/frontend/src/features/Train/features/Image/components/ImageTrainspace.tsx index f01623925..ea548121f 100644 --- a/frontend/src/features/Train/features/Image/components/ImageTrainspace.tsx +++ b/frontend/src/features/Train/features/Image/components/ImageTrainspace.tsx @@ -19,6 +19,7 @@ import { import { useTrainImageMutation } from "../redux/imageApi"; import { useRouter } from "next/router"; import { removeTrainspaceData } from "@/features/Train/redux/trainspaceSlice"; +import { useCreateTrainspaceMutation } from "@/features/Train/redux/trainspaceApi"; const ImageTrainspace = () => { const trainspace = useAppSelector( @@ -93,6 +94,7 @@ const TrainspaceStepInner = ({ const Component = STEP_SETTINGS[TRAINSPACE_SETTINGS.steps[step]].component; const [isStepModified, setIsStepModified] = useState(false); const [train] = useTrainImageMutation(); + const [createTrainspace] = useCreateTrainspaceMutation(); const[isButtonClicked, setIsButtonClicked] = useState(false); const dispatch = useAppDispatch(); const router = useRouter(); @@ -100,13 +102,17 @@ const TrainspaceStepInner = ({ if (trainspace.step < TRAINSPACE_SETTINGS.steps.length) setStep(trainspace.step); else { - train(trainspace) - .unwrap() - .then(({ trainspaceId }) => { - router.push({ pathname: `/train/${trainspaceId}` }).then(() => { - dispatch(removeTrainspaceData()); - }); + const inner = async () => { + const { trainspaceId } = await createTrainspace(trainspace).unwrap(); + await train({ + trainspaceData: trainspace, + trainspaceId: trainspaceId, + }).unwrap(); + router.push({ pathname: `/train/${trainspaceId}` }).then(() => { + dispatch(removeTrainspaceData()); }); + }; + inner(); } }, [trainspace]); if (!Component) return <>; diff --git a/frontend/src/features/Train/features/Image/redux/imageApi.ts b/frontend/src/features/Train/features/Image/redux/imageApi.ts index 7cd5cfb70..56b64f6bf 100644 --- a/frontend/src/features/Train/features/Image/redux/imageApi.ts +++ b/frontend/src/features/Train/features/Image/redux/imageApi.ts @@ -5,12 +5,13 @@ const imageApi = backendApi.injectEndpoints({ endpoints: (builder) => ({ trainImage: builder.mutation< { trainspaceId: string }, - TrainspaceData<"TRAIN"> + { trainspaceData: TrainspaceData<"TRAIN">; trainspaceId: string } >({ - query: (trainspaceData) => ({ + query: ({ trainspaceData, trainspaceId }) => ({ url: "/api/train/img-run", method: "POST", body: { + trainspace_id: trainspaceId, name: trainspaceData.name, data_source: trainspaceData.dataSource, dataset_data: { diff --git a/frontend/src/features/Train/features/Tabular/components/TabularTrainspace.tsx b/frontend/src/features/Train/features/Tabular/components/TabularTrainspace.tsx index 40aef3e08..ab6adc3a1 100644 --- a/frontend/src/features/Train/features/Tabular/components/TabularTrainspace.tsx +++ b/frontend/src/features/Train/features/Tabular/components/TabularTrainspace.tsx @@ -19,6 +19,7 @@ import { import { useTrainTabularMutation } from "../redux/tabularApi"; import { useRouter } from "next/router"; import { removeTrainspaceData } from "@/features/Train/redux/trainspaceSlice"; +import { useCreateTrainspaceMutation } from "@/features/Train/redux/trainspaceApi"; const TabularTrainspace = () => { const trainspace = useAppSelector( @@ -93,6 +94,7 @@ const TrainspaceStepInner = ({ const Component = STEP_SETTINGS[TRAINSPACE_SETTINGS.steps[step]].component; const [isStepModified, setIsStepModified] = useState(false); const [isButtonClicked, setIsButtonClicked] = useState(false); + const [createTrainspace] = useCreateTrainspaceMutation(); const [train] = useTrainTabularMutation(); const dispatch = useAppDispatch(); const router = useRouter(); @@ -112,16 +114,20 @@ const TrainspaceStepInner = ({ if (trainspace.step < TRAINSPACE_SETTINGS.steps.length) setStep(trainspace.step); else { - train(trainspace) - .unwrap() - .then(({ trainspaceId }) => { - router.push({ pathname: `/train/${trainspaceId}` }).then(() => { - dispatch(removeTrainspaceData()); - }); + const inner = async () => { + const { trainspaceId } = await createTrainspace(trainspace).unwrap(); + await train({ + trainspaceData: trainspace, + trainspaceId: trainspaceId, + }).unwrap(); + router.push({ pathname: `/train/${trainspaceId}` }).then(() => { + dispatch(removeTrainspaceData()); }); + }; + inner(); } }, [trainspace]); - + if (!Component) return null; return ( ({ trainTabular: builder.mutation< { trainspaceId: string }, - TrainspaceData<"TRAIN"> + { trainspaceData: TrainspaceData<"TRAIN">; trainspaceId: string } >({ - query: (trainspaceData) => ({ + query: ({ trainspaceData, trainspaceId }) => ({ url: "/api/training/tabular", method: "POST", body: { + trainspace_id: trainspaceId, name: trainspaceData.name, data_source: trainspaceData.dataSource, target: trainspaceData.parameterData.targetCol, diff --git a/frontend/src/features/Train/redux/trainspaceApi.ts b/frontend/src/features/Train/redux/trainspaceApi.ts index e9dad9959..2f6e6ca8c 100644 --- a/frontend/src/features/Train/redux/trainspaceApi.ts +++ b/frontend/src/features/Train/redux/trainspaceApi.ts @@ -2,9 +2,12 @@ import { backendApi } from "@/common/redux/backendApi"; import { DATA_SOURCE, DatasetData, + DetailedTrainResultsData, FileUploadData, } from "@/features/Train/types/trainTypes"; import { fetchBaseQuery } from "@reduxjs/toolkit/dist/query"; +import { TrainspaceData as TabularTrainspaceData } from "../features/Tabular/types/tabularTypes"; +import { TrainspaceData as ImageTrainspaceData } from "../features/Image/types/imageTypes"; const trainspaceApi = backendApi .enhanceEndpoints({ addTagTypes: ["UserDatasetFilesData"] }) @@ -90,6 +93,37 @@ const trainspaceApi = backendApi return response.data; }, }), + createTrainspace: builder.mutation< + { trainspaceId: string }, + TabularTrainspaceData<"TRAIN"> | ImageTrainspaceData<"TRAIN"> + >({ + query: (trainspaceData) => ({ + url: "/api/lambda/trainspace", + method: "POST", + body: { + name: trainspaceData.name, + data_source: trainspaceData.dataSource, + dataset_data: trainspaceData.datasetData, + review_data: trainspaceData.reviewData, + // TODO: add model_id + }, + }), + }), + getTrainspace: builder.query< + { + config: unknown; + detailedTrainResultsData: DetailedTrainResultsData | undefined; + }, + { trainspaceId: string; withResults: boolean } + >({ + query: ({ trainspaceId, withResults }) => ({ + url: `/api/lambda/trainspace/${trainspaceId}`, + method: "GET", + params: { + with_results: withResults, + }, + }), + }), }), overrideExisting: true, }); @@ -98,4 +132,6 @@ export const { useGetDatasetFilesDataQuery, useUploadDatasetFileMutation, useLazyGetColumnsFromDatasetQuery, + useCreateTrainspaceMutation, + useGetTrainspaceQuery, } = trainspaceApi; diff --git a/frontend/src/features/Train/types/trainTypes.ts b/frontend/src/features/Train/types/trainTypes.ts index 886c796db..d390388c8 100644 --- a/frontend/src/features/Train/types/trainTypes.ts +++ b/frontend/src/features/Train/types/trainTypes.ts @@ -1,5 +1,6 @@ import { DATA_SOURCE_ARR } from "../constants/trainConstants"; +// keep in sync with schemas.py export type DATA_SOURCE = typeof DATA_SOURCE_ARR[number]; export type TRAIN_STATUS = @@ -16,16 +17,61 @@ export interface BaseTrainspaceData { step: number; } +// basic information, used on dashboard export interface TrainResultsData { name: string; - trainspaceId: number; + trainspaceId: string; dataSource: DATA_SOURCE; status: TRAIN_STATUS; created: Date; - step: string; uid: string; } +export type CHART_TYPE = "LINE" | "AUC/ROC" | "CONFUSION_MATRIX" + +export type Chart = TimeSeriesChart | AucRocChart | ConfusionMatrixChart + +export interface TimeSeriesMetric { + x_name: string; + y_name: string; + + x_values: number[]; + y_values: number[]; +} + +export interface TimeSeriesChart { + name: string; + + time_series: TimeSeriesMetric[] + chart_type: "LINE" + graph_index: number; +} + +export interface AucRocChart { + name: string; + + values: [number[], number[], number][]; + + chart_type: "AUC/ROC" + graph_index: number; +} + +export interface ConfusionMatrixChart { + name: string; + + values: number[][]; + + chart_type: "CONFUSION_MATRIX" + graph_index: number; +} + +// more detailed information, used when viewing a run +export interface DetailedTrainResultsData { + basic_info: TrainResultsData + + all_metrics: Chart[] +} + export interface FileUploadData { name: string; lastModified: string; diff --git a/frontend/src/pages/train/[train_space_id].tsx b/frontend/src/pages/train/[train_space_id].tsx index e228cb8b4..f22d507bc 100644 --- a/frontend/src/pages/train/[train_space_id].tsx +++ b/frontend/src/pages/train/[train_space_id].tsx @@ -2,281 +2,83 @@ import Footer from "@/common/components/Footer"; import NavbarMain from "@/common/components/NavBarMain"; import { useAppSelector } from "@/common/redux/hooks"; import { isSignedIn } from "@/common/redux/userLogin"; +import { useGetTrainspaceQuery } from "@/features/Train/redux/trainspaceApi"; +import { DetailedTrainResultsData } from "@/features/Train/types/trainTypes"; import Container from "@mui/material/Container"; import Grid from "@mui/material/Grid"; import Paper from "@mui/material/Paper"; -import dynamic from "next/dynamic"; import { useRouter } from "next/router"; -import { Data, XAxisName, YAxisName } from "plotly.js"; import React, { useEffect } from "react"; -const Plot = dynamic(() => import("react-plotly.js"), { ssr: false }); +import { + mapMetricToLinePlot, + mapMetricToAucRocPlot, + mapMetricToConfusionMatrixPlot, +} from "./metrics_to_charts"; + +const mapTrainResultsDataToCharts = ( + detailedTrainResultsData: DetailedTrainResultsData +) => { + // sort by graph_index asc and ignore negative graph indices + const sortedData = detailedTrainResultsData.all_metrics + .filter((metric) => metric.graph_index >= 0) + .sort((a, b) => a.graph_index - b.graph_index); + const charts = []; + let i = 0; + while (i < sortedData.length) { + const metric = sortedData[i]; + if (metric.chart_type === "LINE") { + charts.push(mapMetricToLinePlot(metric)); + } else if (metric.chart_type === "AUC/ROC") { + charts.push(mapMetricToAucRocPlot(metric)); + } else if (metric.chart_type === "CONFUSION_MATRIX") { + charts.push(mapMetricToConfusionMatrixPlot(metric)); + } else { + throw Error("Undefined chart type received"); + } + i += 1; + } + + return charts; +}; const TrainSpace = () => { const { train_space_id } = useRouter().query; - const data = { - success: true, - message: "Dataset trained and results outputted successfully", - dl_results: [ - { - epoch: 1, - train_time: 0.029964923858642578, - train_loss: 1.1126993695894878, - test_loss: 1.1082043647766113, - train_acc: 0.3333333333333333, - "val/test acc": 0.3, - }, - { - epoch: 2, - train_time: 0.0221712589263916, - train_loss: 1.1002190907796223, - test_loss: 1.100191593170166, - train_acc: 0.3333333333333333, - "val/test acc": 0.3, - }, - { - epoch: 3, - train_time: 0.0680840015411377, - train_loss: 1.0896958708763123, - test_loss: 1.0933666229248047, - train_acc: 0.3333333333333333, - "val/test acc": 0.3, - }, - { - epoch: 4, - train_time: 0.007375478744506836, - train_loss: 1.0802951455116272, - test_loss: 1.0868618488311768, - train_acc: 0.3333333333333333, - "val/test acc": 0.3, - }, - { - epoch: 5, - train_time: 0.008754491806030273, - train_loss: 1.071365197499593, - test_loss: 1.080164909362793, - train_acc: 0.3333333333333333, - "val/test acc": 0.3, - }, - ], - auxiliary_outputs: { - confusion_matrix: [ - [0, 0, 6], - [0, 0, 8], - [0, 0, 6], - ], - AUC_ROC_curve_data: [ - [ - [0.0, 0.0, 0.0, 0.07142857142857142, 0.07142857142857142, 1.0], - [ - 0.0, 0.16666666666666666, 0.8333333333333334, 0.8333333333333334, - 1.0, 1.0, - ], - 0.9880952380952381, - ], - [ - [ - 0.0, 0.08333333333333333, 0.5, 0.5, 0.5833333333333334, - 0.5833333333333334, 0.6666666666666666, 0.6666666666666666, 1.0, - ], - [0.0, 0.0, 0.0, 0.75, 0.75, 0.875, 0.875, 1.0, 1.0], - 0.46875, - ], - [ - [0.0, 0.0, 0.0, 0.07142857142857142, 0.07142857142857142, 1.0], - [ - 0.0, 0.16666666666666666, 0.8333333333333334, 0.8333333333333334, - 1.0, 1.0, - ], - 0.9880952380952381, - ], - ], - }, - status: 200, - }; + const { data, isLoading, refetch, error } = useGetTrainspaceQuery({ + trainspaceId: train_space_id, + withResults: true, + }); + const user = useAppSelector((state) => state.currentUser.user); const router = useRouter(); useEffect(() => { if (router.isReady && !user) { + console.log("redirect to login"); router.replace({ pathname: "/login" }); } }, [user, router.isReady]); - if (!isSignedIn(user)) { + + if (error) { + setTimeout(() => refetch(), 3000); + } + + if (!isSignedIn(user) || !data || isLoading) { return <>; } + + const charts = mapTrainResultsDataToCharts( + data.trainspace.detailedTrainResultsData + ); return (

{train_space_id}

- - - x.epoch), - y: data.dl_results.map((x) => x["train_acc"]), - type: "scatter", - mode: "markers", - marker: { color: "red", size: 10 }, - }, - { - name: "Test accuracy", - x: data.dl_results.map((x) => x.epoch), - y: data.dl_results.map((x) => x["val/test acc"]), - type: "scatter", - mode: "markers", - marker: { color: "blue", size: 10 }, - }, - ]} - layout={{ - height: 350, - width: 525, - xaxis: { title: "Epoch Number" }, - yaxis: { title: "Accuracy" }, - title: "Train vs. Test Accuracy for your Deep Learning Model", - showlegend: true, - paper_bgcolor: "rgba(0,0,0,0)", - plot_bgcolor: "rgba(0,0,0,0)", - }} - config={{ responsive: true }} - /> - - - - - x.epoch), - y: data.dl_results.map((x) => x.train_loss), - type: "scatter", - mode: "markers", - marker: { color: "red", size: 10 }, - }, - { - name: "Test loss", - x: data.dl_results.map((x) => x.epoch), - y: data.dl_results.map((x) => x.test_loss), - type: "scatter", - mode: "markers", - marker: { color: "blue", size: 10 }, - }, - ]} - layout={{ - height: 350, - width: 525, - xaxis: { title: "Epoch Number" }, - yaxis: { title: "Loss" }, - title: "Train vs. Test Loss for your Deep Learning Model", - showlegend: true, - paper_bgcolor: "rgba(0,0,0,0)", - plot_bgcolor: "rgba(0,0,0,0)", - }} - config={{ responsive: true }} - /> - - - - - ({ - name: `(AUC: ${x[2]})`, - x: x[0] as number[], - y: x[1] as number[], - type: "scatter", - })) as Data[]), - ]} - layout={{ - height: 350, - width: 525, - xaxis: { title: "False Positive Rate" }, - yaxis: { title: "True Positive Rate" }, - title: "AUC/ROC Curves for your Deep Learning Model", - showlegend: true, - paper_bgcolor: "rgba(0,0,0,0)", - plot_bgcolor: "rgba(0,0,0,0)", - }} - config={{ responsive: true }} - /> - - - - - - row.map((_, j) => ({ - xref: "x1" as XAxisName, - yref: "y1" as YAxisName, - x: j, - y: - (i + - data.auxiliary_outputs.confusion_matrix.length - - 1) % - data.auxiliary_outputs.confusion_matrix.length, - text: data.auxiliary_outputs.confusion_matrix[ - (i + - data.auxiliary_outputs.confusion_matrix.length - - 1) % - data.auxiliary_outputs.confusion_matrix.length - ][j].toString(), - font: { - color: - data.auxiliary_outputs.confusion_matrix[ - (i + - data.auxiliary_outputs.confusion_matrix.length - - 1) % - data.auxiliary_outputs.confusion_matrix.length - ][j] > 0 - ? "white" - : "black", - }, - showarrow: false, - })) - ) - .flat(), - paper_bgcolor: "rgba(0,0,0,0)", - plot_bgcolor: "rgba(0,0,0,0)", - }} - /> - - + {charts.map((chart) => ( + + {chart} + + ))}