Skip to content

Commit 65c8143

Browse files
Sandbox run src/api.py
1 parent a5d9947 commit 65c8143

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

src/api.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,23 @@
11
import torch
2-
from cnn import CNN # Importing CNN class from cnn.py
32
from fastapi import FastAPI, File, UploadFile
43
from PIL import Image
54
from torchvision import transforms
65

6+
from cnn import CNN # Importing CNN class from cnn.py
7+
78
# Load the model
89
model = CNN()
910
model.load_state_dict(torch.load("mnist_model.pth"))
1011
model.eval()
1112

1213
# Transform used for preprocessing the image
13-
transform = transforms.Compose([
14-
transforms.ToTensor(),
15-
transforms.Normalize((0.5,), (0.5,))
16-
])
14+
transform = transforms.Compose(
15+
[transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
16+
)
1717

1818
app = FastAPI()
1919

20+
2021
@app.post("/predict/")
2122
async def predict(file: UploadFile = File(...)):
2223
image = Image.open(file.file).convert("L")

0 commit comments

Comments
 (0)