Skip to content

Commit

Permalink
Fixes
Browse files Browse the repository at this point in the history
Signed-off-by: Krzysztof Lecki <[email protected]>
  • Loading branch information
klecki committed Jun 8, 2022
1 parent d324b97 commit cf9dd80
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 14 deletions.
18 changes: 11 additions & 7 deletions dali/test/python/test_operator_arithmetic_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,8 +544,11 @@ def test_arithmetic_ops():
def test_ternary_ops_big():
for kinds in selected_ternary_input_kinds:
for (op, op_desc) in ternary_operations:
for types_in in [(np.int32, np.int32, np.int32), (np.int32, np.int8, np.int16),
(np.int32, np.uint8, np.float32)]:
for types_in in [
(np.int32, np.int32, np.int32),
(np.int32, np.int8, np.int16),
(np.int32, np.uint8, np.float32),
]:
yield check_ternary_op, kinds, types_in, op, shape_big, op_desc


Expand All @@ -565,9 +568,12 @@ def test_ternary_ops_selected():
def test_ternary_ops_kinds():
for kinds in ternary_input_kinds:
for (op, op_desc) in ternary_operations:
for types_in in [(np.int32, np.int32, np.int32), (np.float32, np.int32, np.int32),
(np.uint8, np.float32, np.float32),
(np.int32, np.float32, np.float32)]:
for types_in in [
(np.int32, np.int32, np.int32),
(np.float32, np.int32, np.int32),
(np.uint8, np.float32, np.float32),
(np.int32, np.float32, np.float32),
]:
yield check_ternary_op, kinds, types_in, op, shape_small, op_desc


Expand Down Expand Up @@ -600,8 +606,6 @@ def test_bitwise_ops():

def check_comparsion_op(kinds, types, op, shape, _):
# Comparisons - should always return bool
left_type, right_type = types
left_kind, right_kind = kinds
iterator = iter(ExternalInputIterator(batch_size, shape, types, kinds))
pipe = ExprOpPipeline(kinds, types, iterator, op, batch_size=batch_size, num_threads=2,
device_id=0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def test_return_empty():
extract_dir = generate_temp_extract(tar_file_path)
equivalent_files = glob(extract_dir.name + "/*")
equivalent_files = sorted(equivalent_files,
key=(lambda s: int(s[s.rfind("/") + 1:s.rfind(".")])))
key=(lambda s: int(s[s.rfind("/") + 1 : s.rfind(".")]))) # noqa: 203

compare_pipelines(
webdataset_raw_pipeline(
Expand Down Expand Up @@ -62,9 +62,9 @@ def test_skip_sample():
extract_dir = generate_temp_extract(tar_file_path)
equivalent_files = list(
filter(
lambda s: int(s[s.rfind("/") + 1:s.rfind(".")]) < 2500,
lambda s: int(s[s.rfind("/") + 1 : s.rfind(".")]) < 2500, # noqa: 203
sorted(glob(extract_dir.name + "/*"),
key=lambda s: int(s[s.rfind("/") + 1:s.rfind(".")])),
key=lambda s: int(s[s.rfind("/") + 1 : s.rfind(".")])), # noqa: 203
))

compare_pipelines(
Expand Down Expand Up @@ -120,7 +120,7 @@ def test_different_components():
extract_dir = generate_temp_extract(tar_file_path)
equivalent_files = glob(extract_dir.name + "/*")
equivalent_files = sorted(equivalent_files,
key=(lambda s: int(s[s.rfind("/") + 1:s.rfind(".")])))
key=(lambda s: int(s[s.rfind("/") + 1 : s.rfind(".")]))) # noqa: 203

compare_pipelines(
webdataset_raw_pipeline(
Expand Down Expand Up @@ -176,7 +176,7 @@ def test_wds_sharding():
equivalent_files = sum(
list(
sorted(glob(extract_dir.name +
"/*"), key=lambda s: int(s[s.rfind("/") + 1:s.rfind(".")]))
"/*"), key=lambda s: int(s[s.rfind("/") + 1 : s.rfind(".")])) # noqa: 203
for extract_dir in extract_dirs),
[],
)
Expand Down Expand Up @@ -210,7 +210,7 @@ def test_sharding():

extract_dir = generate_temp_extract(tar_file_path)
equivalent_files = sorted(glob(extract_dir.name + "/*"),
key=lambda s: int(s[s.rfind("/") + 1:s.rfind(".")]))
key=lambda s: int(s[s.rfind("/") + 1 : s.rfind(".")])) # noqa: 203

num_shards = 100
for shard_id in range(num_shards):
Expand Down Expand Up @@ -287,7 +287,7 @@ def test_index_generation():
equivalent_files = sum(
list(
sorted(glob(extract_dir.name +
"/*"), key=lambda s: int(s[s.rfind("/") + 1:s.rfind(".")]))
"/*"), key=lambda s: int(s[s.rfind("/") + 1 : s.rfind(".")])) # noqa: 203
for extract_dir in extract_dirs),
[],
)
Expand Down

0 comments on commit cf9dd80

Please sign in to comment.