-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathdemo.py
90 lines (78 loc) · 2.26 KB
/
demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import argparse
import imageio as imageio
import torch
import os
import satellighte as sat
import time
def parse_arguments(argv=None):
    """
    Parse command line arguments.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, so existing
            callers that pass nothing behave exactly as before.

    Returns:
        argparse.Namespace: Parsed arguments with the attributes
        ``device``, ``model_name``, ``version`` (``None`` when omitted)
        and ``source``.
    """
    # Query the model registry once so the default and the allowed choices
    # are guaranteed to come from the same list.
    models = sat.available_models()

    arg = argparse.ArgumentParser()
    arg.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        choices=["cuda", "cpu"],
        # Both GPU and CPU are accepted, so the help must not say "GPU" only.
        help="Device to run the model on",
    )
    arg.add_argument(
        "--model_name",
        type=str,
        default=models[0],
        choices=models,
        help="Model architecture",
    )
    arg.add_argument(
        "--version",
        type=str,
        help="Model version",
    )
    arg.add_argument(
        "--source",
        "-s",
        type=str,
        required=True,
        help="Path to the image file or directory",
    )
    return arg.parse_args(argv)
def _resolve_version(model_name, requested_version):
    """
    Return the model version to load, validating an explicitly requested one.

    Args:
        model_name: Name of the model architecture.
        requested_version: Version from the command line; a falsy value
            selects the latest version reported by ``sat``.

    Returns:
        The version identifier to pass to ``sat.Classifier.from_pretrained``.

    Raises:
        ValueError: If ``requested_version`` is not among the versions
            available for ``model_name``.
    """
    if not requested_version:
        return sat.get_model_latest_version(model_name)
    # Fetch the list once: it is used both for the check and the message.
    available = sat.get_model_versions(model_name)
    if requested_version not in available:
        raise ValueError(
            f"model version {requested_version} not available for model {model_name}, available versions: {available}"
        )
    return requested_version


def _classify_and_show(model, img):
    """Run ``model`` on ``img``, display the visualization and print the results."""
    results = model.predict(img)
    pil_img = sat.utils.visualize(img, results)
    pil_img.show()
    print(results)


def main(args):
    """
    Main function: load the requested classifier and run it on ``args.source``.

    ``args.source`` may be a single image file, or a directory whose regular
    files are classified one by one in sorted name order.

    Args:
        args: Parsed arguments providing ``device``, ``model_name``,
            ``version`` and ``source``.

    Raises:
        FileNotFoundError: If ``args.source`` is neither a file nor a directory.
        ValueError: If ``args.version`` is given but not available for
            ``args.model_name``.
    """
    is_dir = os.path.isdir(args.source)
    # Fail fast, before loading the model, instead of silently doing nothing.
    if not is_dir and not os.path.isfile(args.source):
        raise FileNotFoundError(
            f"source {args.source!r} is neither a file nor a directory"
        )

    version = _resolve_version(args.model_name, args.version)
    model = sat.Classifier.from_pretrained(args.model_name, version=version)
    model.eval()
    model.to(args.device)
    model.summarize()

    if not is_dir:
        _classify_and_show(model, imageio.imread(args.source))
        return

    # os.listdir returns entries in arbitrary order; sort for reproducible runs.
    for file in sorted(os.listdir(args.source)):
        file_path = os.path.join(args.source, file)
        if not os.path.isfile(file_path):
            continue  # skip anything that is not a regular file (e.g. sub-directories)
        print(file)
        try:
            img = imageio.imread(file_path)
        except (OSError, ValueError) as err:
            # Directories often contain non-image files; skip them rather
            # than aborting the whole run.
            print(f"skipping {file}: {err}")
            continue
        _classify_and_show(model, img)
        # Presumably spaces out the image viewer windows opened by show().
        time.sleep(1)
if __name__ == "__main__":
    # Script entry point: parse the command line and hand the result to main().
    main(parse_arguments())