Skip to content

Commit 2a5a0f2

Browse files
committed
Add pipit
1 parent 5027bb7 commit 2a5a0f2

File tree

9 files changed

+34
-12
lines changed

9 files changed

+34
-12
lines changed

README.md

+18
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,21 @@
77
Pre-trained models based on TensorLayerX.
88
TensorLayerX is a multi-backend AI framework, which can run on almost all operation systems and AI hardwares, and support hybrid-framework programming. The currently version supports TensorFlow, MindSpore, PaddlePaddle, PyTorch, OneFlow and Jittor as the backends.
99

10+
# Quick Start
11+
## Installation
12+
### Via pip
13+
```bash
14+
# install from pypi
15+
pip3 install tlxzoo
16+
```
17+
18+
## train
19+
```bash
20+
python demo/vision/image_classification/vgg/train.py
21+
```
22+
23+
## predict
24+
25+
```bash
26+
python demo/vision/image_classification/vgg/predict.py
27+
```

demo/vision/face_recognition/retinaface/train.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def tf_train(
8989
root_path="./wider/widerface",
9090
train_ann_path="label.txt",
9191
val_ann_path="label.txt",
92-
num_workers=2,
92+
num_workers=0,
9393
)
9494
transform = RetinaFaceTransform()
9595
wider.register_transform_hook(transform)

demo/vision/human_pose_estimation/hrnet/predict.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
model.load_weights("./demo/vision/human_pose_estimation/hrnet/model.npz")
1313
model.set_eval()
1414

15-
path = "./demo/vision/human_pose_estimation/hrnet/hrnet.jpg"
15+
path = "./coco2017/0.1/val2017/000000527784.jpg"
1616
image = Image.open(path).convert('RGB')
1717
image_height, image_width = image.height, image.width
1818
image = np.array(image, dtype=np.float32)

demo/vision/human_pose_estimation/hrnet/train.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def tf_train(
6464

6565
if (epoch + 1) % 5 == 0:
6666
valid(network, test_dataset)
67-
model.save_weights("./demo/vision/human_pose_estimation/hrnet/model.npz")
67+
model.save_weights("./demo/vision/human_pose_estimation/model.npz")
6868

6969
optimizer.lr.step()
7070

@@ -75,10 +75,10 @@ def __init__(self, learning_rate, last_epoch=0, verbose=False):
7575

7676
def get_lr(self):
7777

78-
if int(self.last_epoch) >= 40:
78+
if int(self.last_epoch) >= 65:
7979
return self.base_lr * 0.01
8080

81-
if int(self.last_epoch) >= 30:
81+
if int(self.last_epoch) >= 40:
8282
return self.base_lr * 0.1
8383

8484
return self.base_lr
@@ -91,7 +91,7 @@ def get_lr(self):
9191
data_name="Coco",
9292
train_ann_path="./annotations/person_keypoints_train2017.json",
9393
val_ann_path="./annotations/person_keypoints_val2017.json",
94-
num_workers=4)
94+
num_workers=0)
9595

9696
transform = HRNetTransform()
9797
datasets.register_transform_hook(transform)
@@ -103,7 +103,7 @@ def get_lr(self):
103103
# optimizer = tlx.optimizers.SGD(lr=scheduler)
104104

105105
trainer = Trainer(network=model, loss_fn=model.loss_fn, optimizer=optimizer, metrics=None)
106-
trainer.train(n_epoch=50, train_dataset=datasets.train, test_dataset=datasets.test, print_freq=1,
106+
trainer.train(n_epoch=80, train_dataset=datasets.train, test_dataset=datasets.test, print_freq=1,
107107
print_train_batch=True)
108108

109-
model.save_weights("./demo/vision/human_pose_estimation/hrnet/model.npz")
109+
model.save_weights("./demo/vision/human_pose_estimation/model.npz")

demo/vision/image_classification/vgg/predict.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
model.load_weights("./demo/vision/image_classification/vgg/model.npz")
2323
model.set_eval()
2424

25-
image = cv2.imread("dog.png")
25+
image = cv2.imread("./demo/vision/image_classification/vgg/dog.png")
2626
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
2727
image = cv2.resize(image, (32, 32))
2828

docker/Dockerfile

+1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ RUN pip install pandas
88
RUN pip install jupyterlab
99
RUN pip install opencv-python
1010
RUN pip install SoundFile
11+
RUN apt-get install libsndfile1
1112
RUN pip install sentencepiece
1213
RUN pip install sacrebleu
1314
RUN pip install rouge_score

setup.cfg

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
[metadata]
2+
desciption-file = README.md

setup.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@ def req_file(filename, folder="requirements"):
4141
long_description_content_type="text/markdown",
4242
author="tensorlayerx",
4343
author_email="",
44-
url="",
44+
packages=find_packages(),
45+
url="https://github.com/tensorlayer/TLXZoo",
4546
keywords="tensorlayerx zoo",
4647
python_requires=">=3.5",
4748
install_requires=install_requires,

tlxzoo/datasets/data_loader.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
class DataLoaders(object):
88
def __init__(self, data_name, train_limit=None, collate_fn=None, transform_hook=None, transform_hook_index=None,
9-
num_workers=2, per_device_train_batch_size=2, per_device_eval_batch_size=2, **kwargs):
9+
num_workers=0, per_device_train_batch_size=2, per_device_eval_batch_size=2, **kwargs):
1010
self.dataset_dict = Registers.datasets[data_name].load(train_limit=train_limit, **kwargs)
1111

1212
if "train" in self.dataset_dict:
@@ -68,7 +68,7 @@ def register_transform_hook(self, transform_hook, index=None):
6868
transform_hook.set_eval()
6969
self.dataset_dict["test"].register_transform_hook(transform_hook, index=index)
7070

71-
def dataset_dataloader(self, dataset, dataset_type="train", num_workers=8, collate_fn=None,
71+
def dataset_dataloader(self, dataset, dataset_type="train", num_workers=0, collate_fn=None,
7272
per_device_train_batch_size=2, per_device_eval_batch_size=2):
7373

7474
if dataset_type == "train":

0 commit comments

Comments
 (0)