Commit 294dd61

Merge pull request #19 from delta-mpc/optimize
Optimize
2 parents 86ef8b0 + 80f56e6 · commit 294dd61

File tree

6 files changed: +87 -34 lines changed

delta/dataset/dataset.py
delta/dataset/file.py
delta/delta_node.py
delta/pandas/dataframe.py
delta/task/learning.py
setup.py

delta/dataset/dataset.py

Lines changed: 12 additions & 11 deletions
@@ -31,8 +31,8 @@ def __init__(


 class FileDataset(TorchDataset):
-    def __init__(self, filename: str) -> None:
-        result = load_file(filename)
+    def __init__(self, filename: str, **kwargs) -> None:
+        result = load_file(filename, **kwargs)
         if isinstance(result, Image.Image):
             raise ValueError("file dataset does not support image file")
         self._result = result
@@ -50,9 +50,10 @@ def __len__(self) -> int:


 class DirectoryDataset(TorchDataset):
-    def __init__(self, directory: str) -> None:
+    def __init__(self, directory: str, **kwargs) -> None:
         self._xs = []
         self._ys = []
+        self._kwargs = kwargs
         root, dirnames, filenames = next(os.walk(directory))
         if len(filenames) > 0 and len(dirnames) == 0:
             self._xs.extend([os.path.join(root, filename) for filename in filenames])
@@ -77,7 +78,7 @@ def __init__(self, directory: str) -> None:

     def __getitem__(self, index):
         filename = self._xs[index]
-        x = load_file(filename)
+        x = load_file(filename, **self._kwargs)
         y = None
         if len(self._ys) > 0:
             y = self._ys[index]
@@ -121,12 +122,12 @@ def split_dataset(


 def load_dataset(
-    dataset_name: str,
+    dataset_name: str, **kwargs
 ) -> TorchDataset | Tuple[TorchDataset, TorchDataset]:
     if not os.path.exists(dataset_name):
         raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), dataset_name)
     if os.path.isfile(dataset_name):
-        dataset = FileDataset(dataset_name)
+        dataset = FileDataset(dataset_name, **kwargs)
         return dataset
     else:
         train_path = os.path.join(dataset_name, "train")
@@ -140,12 +141,12 @@ def load_dataset(
             train_root, _, train_files = next(os.walk(train_path))
             val_root, _, val_files = next(os.walk(val_path))
             if len(train_files) == 1 and len(val_files) == 1:
-                train_dataset = FileDataset(os.path.join(train_root, train_files[0]))
-                val_dataset = FileDataset(os.path.join(val_root, val_files[0]))
+                train_dataset = FileDataset(os.path.join(train_root, train_files[0]), **kwargs)
+                val_dataset = FileDataset(os.path.join(val_root, val_files[0]), **kwargs)
             else:
-                train_dataset = DirectoryDataset(train_path)
-                val_dataset = DirectoryDataset(val_path)
+                train_dataset = DirectoryDataset(train_path, **kwargs)
+                val_dataset = DirectoryDataset(val_path, **kwargs)
             return train_dataset, val_dataset
         else:
-            dataset = DirectoryDataset(dataset_name)
+            dataset = DirectoryDataset(dataset_name, **kwargs)
             return dataset
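
The point of threading **kwargs through load_dataset, FileDataset, and DirectoryDataset is that reader options now reach load_file (and from there the underlying pandas readers). A minimal usage sketch; the paths, reader options, and the delta.dataset.dataset import path are illustrative assumptions, not part of this diff:

from delta.dataset.dataset import load_dataset

# Hypothetical single tabular file: the extra kwargs travel through
# FileDataset into load_file() and on to the matching pandas reader.
dataset = load_dataset("data/samples.xlsx", sheet_name="train")

# Hypothetical train/val directory of tabular files: DirectoryDataset
# stores the kwargs and re-applies them to load_file() for every file.
train_dataset, val_dataset = load_dataset("data/my_dataset", dtype="float32")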

delta/dataset/file.py

Lines changed: 2 additions & 0 deletions
@@ -39,6 +39,8 @@ def load_file(
         result = pd.read_csv(filename, sep=r"\s+", **kwargs)
     elif filename.endswith(".xls") or filename.endswith(".xlsx"):
         result = pd.read_excel(filename, **kwargs)
+    elif filename.endswith(".json"):
+        result = pd.read_json(filename, **kwargs)
     else:
         try:
             result = Image.open(filename, **kwargs)
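
With the new branch, .json files are parsed by pd.read_json, and any of its keyword arguments can be forwarded through the **kwargs plumbing above. A small sketch; the file name and its contents are hypothetical:

from delta.dataset.file import load_file

# Assuming labels.json holds records like [{"x": 1.0, "y": 0}, {"x": 2.0, "y": 1}],
# load_file() now returns a DataFrame; orient is forwarded to pd.read_json.
df = load_file("labels.json", orient="records")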

delta/delta_node.py

Lines changed: 19 additions & 4 deletions
@@ -23,10 +23,25 @@ def __init__(self, url: str) -> None:

     def create_task(self, task: Task) -> int:
         url = f"{self._url}/v1/task"
-        with TemporaryFile(mode="w+b") as file:
-            serialize.dump_task(file, task)
-            file.seek(0)
-            resp = httpx.post(url, files={"file": file}, timeout=None)
+        with TemporaryFile(mode="w+b") as task_file, TemporaryFile(mode="w+b") as config_file:
+            task_config = {
+                "name": task.name,
+                "dataset": task.dataset,
+                "type": task.type,
+                "enable_verify": task.enable_verify,
+                "options": task.options
+            }
+            pickle.dump(task_config, config_file)
+            config_file.seek(0)
+
+            serialize.dump_task(task_file, task)
+            task_file.seek(0)
+            files = {
+                "file": ("task_file.pkl", task_file, "application/octet-stream"),
+                "config": ("task_config_file.pkl", config_file, "application/pickle")
+            }
+
+            resp = httpx.post(url, files=files, timeout=None)
         resp.raise_for_status()
         data = resp.json()
         task_id = data["task_id"]
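
A plausible reading of this change: the cloudpickle'd task body stays in the "file" part of the multipart upload, while a plain-pickle "config" part carries just the metadata, so the node can inspect a task's name, type, and options without deserializing (and thus without executing) the full payload. A server-side sketch under that assumption; read_task_config and its argument are hypothetical, not part of this diff:

import pickle
from typing import BinaryIO

def read_task_config(config_part: BinaryIO) -> dict:
    # The "config" part is a plain pickle of a small dict, so it can be
    # loaded on its own, without touching the cloudpickle payload in "file".
    task_config = pickle.load(config_part)
    assert {"name", "dataset", "type", "enable_verify", "options"} <= task_config.keys()
    return task_config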

delta/pandas/dataframe.py

Lines changed: 1 addition & 1 deletion
@@ -91,7 +91,7 @@ def map(self, data: pandas.DataFrame) -> pandas.Series:

     def _dispatch_binary_op(
         self,
-        other: "DataFrame" | "Series" | List[float] | float,
+        other: "DataFrame | Series | List[float] | float",
         op_name: str,
         op: Callable[..., Any],
         **kwargs: Any,
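
This one-line fix addresses a real crash, not style: PEP 604 unions are evaluated when the def statement runs, and "str | str" is not a valid operation, so quoting each name individually raises TypeError at import time. Quoting the whole union makes it a single forward-reference string that is only resolved if something calls typing.get_type_hints. A minimal reproduction:

from typing import List

try:
    # Evaluated at definition time: "DataFrame" | "Series" is str | str -> TypeError
    def bad(other: "DataFrame" | "Series" | List[float] | float): ...
except TypeError as e:
    print(e)  # unsupported operand type(s) for |: 'str' and 'str'

# The whole-string form is deferred, so this definition succeeds.
def good(other: "DataFrame | Series | List[float] | float"): ...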

delta/task/learning.py

Lines changed: 44 additions & 9 deletions
@@ -121,27 +121,55 @@ def __init__(
         self.epoch = epoch
         self.iteration = iteration
         self.strategy = strategy
+        self.batch_sampler = None

     def __iter__(self):
         return self._get_iter()

+    def _make_dataloader(self) -> DataLoader:
+        if self.batch_sampler is None:
+            return self.dataloader
+        else:
+            return DataLoader(
+                dataset=self.dataloader.dataset,
+                batch_sampler=self.batch_sampler,
+                num_workers=self.dataloader.num_workers,
+                collate_fn=self.dataloader.collate_fn,
+                pin_memory=self.dataloader.pin_memory,
+                timeout=self.dataloader.timeout,
+                worker_init_fn=self.dataloader.worker_init_fn,
+                multiprocessing_context=self.dataloader.multiprocessing_context,
+                generator=self.dataloader.generator,
+                prefetch_factor=self.dataloader.prefetch_factor,
+                persistent_workers=self.dataloader.persistent_workers,
+                pin_memory_device=self.dataloader.pin_memory_device,
+            )
+
     def _get_iter(self):
         finished = False

         while not finished:
-            for batch in self.dataloader:
+            count = 0
+            dataloader = self._make_dataloader()
+            for batch in dataloader:
                 if finished:
                     break

                 _logger.info(f"Training epoch {self.epoch} iteration {self.iteration}")

                 yield batch
-
+
+                count += 1
                 if self.strategy.should_merge(self.epoch, self.iteration, False):
                     _logger.info(f"iteration {self.iteration}, start to merge")
+                    assert dataloader.batch_sampler is not None
+                    if self.batch_sampler is None:
+                        self.batch_sampler = list(dataloader.batch_sampler)
+                    self.batch_sampler = self.batch_sampler[count:]
                     finished = True
                 self.iteration += 1
-
+
+            self.batch_sampler = None
             if self.strategy.should_merge(self.epoch, self.iteration, True):
                 _logger.info(f"epoch {self.epoch}, start to merge")
                 finished = True
@@ -372,9 +400,14 @@ def map(
         epoch: int,
         iteration: int,
     ) -> Tuple[Dict[str, np.ndarray], int, int]:
-        self.learning.strategy.weight_to_params(
-            weight, self.learning.state_dict()
-        )
+        if len(weight) > 0:
+            self.learning.strategy.weight_to_params(
+                weight, self.learning.state_dict()
+            )
+        else:
+            weight = self.learning.strategy.params_to_weight(
+                self.learning.state_dict()
+            )
         _logger.info(f"Round {self.round} training")
         train_iter = TrainIterator(
             dataloader, epoch, iteration, self.learning.strategy
@@ -517,7 +550,10 @@ def reduce(
             self.learning.strategy.weight_to_params(
                 weight, self.learning.state_dict()
             )
-            return self.learning.state_dict()
+            res: Dict[str, Any] = {"weight": self.learning.state_dict()}
+            if metrics is not None:
+                res["metrics"] = metrics
+            return res

         input_nodes: List[DataNode] = [weight_node]
         if metrics_node is not None:
@@ -541,9 +577,8 @@ def _build_graph(self) -> Tuple[List[delta.dataset.Dataset], List[GraphNode]]:
         iteration_node = InputGraphNode(
             name="iteration", location=DataLocation.CLIENT, default=1
         )
-        weight_arr = self.strategy.params_to_weight(self.state_dict())
         weight_node = InputGraphNode(
-            name="weight_0", location=DataLocation.SERVER, default=weight_arr
+            name="weight_0", location=DataLocation.SERVER, default=np.empty(0)
         )
         metrics_node = None
         inputs = [dataset_node, epoch_node, iteration_node, weight_node]
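
The batch_sampler bookkeeping in TrainIterator lets training pause for a merge mid-epoch and later resume from the first unconsumed batch instead of replaying the whole epoch: the epoch's batch plan is materialized with list(dataloader.batch_sampler), sliced past the consumed batches, and fed to a rebuilt DataLoader. A standalone toy sketch of the same mechanism, outside the Delta classes (names and data are illustrative):

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(10).float())
loader = DataLoader(dataset, batch_size=2, shuffle=True)

# Materialize this epoch's batch plan, pretend two batches were consumed
# before a merge, then rebuild a loader over the remaining plan -- the same
# idea as _make_dataloader() with the saved self.batch_sampler slice.
batch_plan = list(loader.batch_sampler)          # e.g. [[3, 7], [0, 9], [4, 1], ...]
resume_loader = DataLoader(dataset, batch_sampler=batch_plan[2:])

for (batch,) in resume_loader:
    print(batch)  # the three remaining batches, in the original shuffled order

The other notable change in this file is lazy weight initialization: weight_0 now defaults to np.empty(0), and map falls back to params_to_weight(self.learning.state_dict()) when the weight is empty, so the first round's weights come from the client's own model instead of being serialized into the task graph.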

setup.py

Lines changed: 9 additions & 9 deletions
@@ -30,20 +30,20 @@ def run_tests(self):

 setup(
     name="delta-task",
-    version="0.8.3",
+    version="0.8.4rc1",
     license_files=("LICENSE"),
     packages=find_packages(),
     include_package_data=True,
     exclude_package_data={"": [".gitignore"]},
     install_requires=[
-        "cloudpickle==1.6.0",
-        "httpx==0.23.0",
-        "numpy==1.22.0",
-        "Pillow==9.1.1",
-        "pandas==1.2.3",
-        "pytest==6.2.5",
-        "torch==1.8.2+cpu",
-        "networkx==2.7.1"
+        "cloudpickle>=1.6.0",
+        "httpx>=0.23.0",
+        "numpy>=1.22.0",
+        "Pillow>=9.1.1",
+        "pandas>=1.2.3",
+        "pytest>=6.2.5",
+        "torch>=1.8.2",
+        "networkx>=2.7.1"
     ],
     tests_require=["pytest"],
     cmdclass={"test": PyTest},
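
Two things changed here: exact pins became lower bounds, and the +cpu local version tag was dropped from torch. The latter matters because PEP 440 local versions such as 1.8.2+cpu exist only on PyTorch's own wheel index, so a plain pip install of delta-task from PyPI could never satisfy the old pin. A small sketch with the packaging library; the version numbers tested are arbitrary:

from packaging.specifiers import SpecifierSet
from packaging.version import Version

# The old pin only matches the exact local build...
print(Version("1.8.2+cpu") in SpecifierSet("==1.8.2+cpu"))  # True
print(Version("1.8.2") in SpecifierSet("==1.8.2+cpu"))      # False
# ...while the new lower bound accepts any torch from 1.8.2 up.
print(Version("1.13.0") in SpecifierSet(">=1.8.2"))         # True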
