-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathAPI.py
More file actions
102 lines (78 loc) · 2.82 KB
/
API.py
File metadata and controls
102 lines (78 loc) · 2.82 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
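# Flask app that serves image classification:
# - an endpoint to upload an image,
# - saving the uploaded image to disk,
# - a function that runs the model's prediction on the image,
# - rendering the result back to the user.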
import os
import torch
import albumentations
import pretrainedmodels
import numpy as np
import torch.nn as nn
from flask import Flask
from flask import render_template
from flask import request
from torch.nn import functional as F
from wtfml.data_loaders.image.loader import ClassificationLoader
from wtfml.engine import Engine
app = Flask(__name__)

# Directory where uploaded images are saved before prediction. Can be
# overridden with the UPLOAD_FOLDER environment variable; the default is the
# original author's local path. (Presumably this should be the app's static
# folder so the template can display the image — verify.)
UPLOAD_FOLDER = os.environ.get(
    "UPLOAD_FOLDER", "C:/Users/bulig/PycharmProjects/pythonProject/static"
)
# Torch device used both for loading the weights and for inference.
DEVICE = "cpu"
# Populated by the ``if __name__ == "__main__":`` entry point; it stays None
# when this module is imported some other way (e.g. ``flask run``).
MODEL = None
class SEResnext50_32x4d(nn.Module):
    """SE-ResNeXt50 (32x4d) backbone followed by a single-output sigmoid head.

    The backbone is taken from ``pretrainedmodels``. Its feature map is
    average-pooled to one 2048-d vector per image, which a linear layer maps
    to a single value squashed by a sigmoid.
    """

    def __init__(self, pretrained="imagenet"):
        """Build the backbone and the linear head.

        Args:
            pretrained: Weight-set name forwarded to
                ``pretrainedmodels.se_resnext50_32x4d``. The entry point of
                this script passes ``None`` and then loads its own state dict.
        """
        super().__init__()
        self.base_model = pretrainedmodels.__dict__[
            "se_resnext50_32x4d"
        ](pretrained=pretrained)
        self.l0 = nn.Linear(2048, 1)

    def forward(self, image, targets=None):
        """Score a batch of images.

        Args:
            image: 4-D tensor whose first dimension is the batch size
                (presumably (batch, 3, H, W) — TODO confirm).
            targets: Unused. It is accepted because the loader batches carry
                a ``targets`` entry (presumably passed through by wtfml's
                ``Engine`` — confirm). It defaults to ``None`` so the model
                can also be called with images alone.

        Returns:
            ``(out, loss)``:
                - ``out`` is a (batch, 1) tensor of sigmoid outputs.
                - ``loss`` is always ``0``; no loss is computed here.
        """
        bs, _, _, _ = image.shape
        x = self.base_model.features(image)
        # Global average pool to (bs, 2048, 1, 1), then flatten to (bs, 2048).
        x = F.adaptive_avg_pool2d(x, 1).reshape(bs, -1)
        out = torch.sigmoid(self.l0(x))
        loss = 0
        return out, loss
def predict(image_path, model):
    """Run ``model`` on a single image file and return its flattened predictions.

    The image is normalised with the usual ImageNet channel mean/std (pixel
    values divided by 255). It is wrapped in a one-item
    ``ClassificationLoader`` with a dummy target of 0 and scored by a wtfml
    ``Engine`` on ``DEVICE``.

    Args:
        image_path: Path of the image file to score.
        model: The network to evaluate.

    Returns:
        A 1-D numpy array made by stacking the engine's predictions and
        flattening them.
    """
    normalize = albumentations.Normalize(
        (0.485, 0.456, 0.406),
        (0.229, 0.224, 0.225),
        max_pixel_value=255.0,
        always_apply=True,
    )
    augmentations = albumentations.Compose([normalize])

    # A single image with a placeholder label, no resizing, in one batch.
    single_image_loader = torch.utils.data.DataLoader(
        ClassificationLoader(
            image_paths=[image_path],
            targets=[0],
            resize=None,
            augmentations=augmentations,
        ),
        batch_size=1,
        shuffle=False,
        num_workers=0,
    )

    # The Engine is constructed with an optimizer even though nothing is
    # trained here; only its predict pass is used.
    engine = Engine(
        model,
        optimizer=torch.optim.Adam(model.parameters(), lr=1e-4),
        device=DEVICE,
    )
    batch_outputs = engine.predict(single_image_loader)
    return np.vstack(batch_outputs).ravel()
def _safe_upload_name(filename):
    """Reduce a client-supplied file name to a bare name, or None if unusable.

    Any directory components are stripped, whether written with ``/`` or
    ``\\``, so a name such as ``../../app.py`` cannot escape the upload
    folder when joined to it.
    """
    name = os.path.basename(filename.replace("\\", "/"))
    if name in ("", ".", ".."):
        return None
    return name


@app.route("/", methods=["GET", "POST"])
def upload_predict():
    """Render the upload page and, on POST, save and score the uploaded image.

    A POST with a usable ``image`` file saves it under ``UPLOAD_FOLDER``
    (sanitised name only), runs ``predict`` with ``MODEL``, and renders
    ``index.html`` with the first prediction and the saved file name.

    Every other request renders ``index.html`` with ``prediction=0`` and
    ``image_loc=None``. That includes GET, a POST without an ``image``
    field, an empty field, or an unusable name.
    """
    if request.method == "POST":
        image_file = request.files.get("image")
        # werkzeug's FileStorage is falsy when no filename was submitted.
        if image_file:
            filename = _safe_upload_name(image_file.filename)
            if filename is not None:
                image_location = os.path.join(UPLOAD_FOLDER, filename)
                image_file.save(image_location)
                pred = predict(image_location, MODEL)[0]
                return render_template(
                    "index.html", prediction=pred, image_loc=filename
                )
    return render_template("index.html", prediction=0, image_loc=None)
if __name__ == "__main__":
    # Trained weights. Can be overridden with MODEL_PATH; the default is the
    # original author's local path.
    model_path = os.environ.get(
        "MODEL_PATH",
        "C:/Users/bulig/PycharmProjects/pythonProject/input/model_fold_4.bin",
    )
    # pretrained=None: the backbone's weights come from the state dict below.
    MODEL = SEResnext50_32x4d(pretrained=None)
    MODEL.load_state_dict(
        torch.load(model_path, map_location=torch.device(DEVICE))
    )
    MODEL.to(DEVICE)
    # Inference only: put dropout / batch-norm layers into evaluation mode.
    MODEL.eval()
    # The Werkzeug debugger lets anyone who can reach it execute arbitrary
    # code, and host="0.0.0.0" listens on every interface. Debug mode is
    # therefore opt-in via FLASK_DEBUG, using Flask's convention of treating
    # anything other than unset/"0"/"false"/"no" as enabled.
    debug = os.environ.get("FLASK_DEBUG", "").lower() not in ("", "0", "false", "no")
    app.run(host="0.0.0.0", port=12000, debug=debug)