Skip to content

Commit 36b1800

Browse files
authored
jit_apply_function (pytorch#7)
1 parent 41ecf3c commit 36b1800

File tree

8 files changed

+620
-95
lines changed

8 files changed

+620
-95
lines changed

.clang-format

+88
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
---
2+
AccessModifierOffset: -1
3+
AlignAfterOpenBracket: AlwaysBreak
4+
AlignConsecutiveAssignments: false
5+
AlignConsecutiveDeclarations: false
6+
AlignEscapedNewlinesLeft: true
7+
AlignOperands: false
8+
AlignTrailingComments: false
9+
AllowAllParametersOfDeclarationOnNextLine: false
10+
AllowShortBlocksOnASingleLine: false
11+
AllowShortCaseLabelsOnASingleLine: false
12+
AllowShortFunctionsOnASingleLine: Empty
13+
AllowShortIfStatementsOnASingleLine: false
14+
AllowShortLoopsOnASingleLine: false
15+
AlwaysBreakAfterReturnType: None
16+
AlwaysBreakBeforeMultilineStrings: true
17+
AlwaysBreakTemplateDeclarations: true
18+
BinPackArguments: false
19+
BinPackParameters: false
20+
BraceWrapping:
21+
AfterClass: false
22+
AfterControlStatement: false
23+
AfterEnum: false
24+
AfterFunction: false
25+
AfterNamespace: false
26+
AfterObjCDeclaration: false
27+
AfterStruct: false
28+
AfterUnion: false
29+
BeforeCatch: false
30+
BeforeElse: false
31+
IndentBraces: false
32+
BreakBeforeBinaryOperators: None
33+
BreakBeforeBraces: Attach
34+
BreakBeforeTernaryOperators: true
35+
BreakConstructorInitializersBeforeComma: false
36+
BreakAfterJavaFieldAnnotations: false
37+
BreakStringLiterals: false
38+
ColumnLimit: 80
39+
CommentPragmas: '^ IWYU pragma:'
40+
CompactNamespaces: false
41+
ConstructorInitializerAllOnOneLineOrOnePerLine: true
42+
ConstructorInitializerIndentWidth: 4
43+
ContinuationIndentWidth: 4
44+
Cpp11BracedListStyle: true
45+
DerivePointerAlignment: false
46+
DisableFormat: false
47+
ForEachMacros: [ FOR_EACH_RANGE, FOR_EACH, ]
48+
IncludeCategories:
49+
- Regex: '^<.*\.h(pp)?>'
50+
Priority: 1
51+
- Regex: '^<.*'
52+
Priority: 2
53+
- Regex: '.*'
54+
Priority: 3
55+
IndentCaseLabels: true
56+
IndentWidth: 2
57+
IndentWrappedFunctionNames: false
58+
KeepEmptyLinesAtTheStartOfBlocks: false
59+
MacroBlockBegin: ''
60+
MacroBlockEnd: ''
61+
MaxEmptyLinesToKeep: 1
62+
NamespaceIndentation: None
63+
ObjCBlockIndentWidth: 2
64+
ObjCSpaceAfterProperty: false
65+
ObjCSpaceBeforeProtocolList: false
66+
PenaltyBreakBeforeFirstCallParameter: 1
67+
PenaltyBreakComment: 300
68+
PenaltyBreakFirstLessLess: 120
69+
PenaltyBreakString: 1000
70+
PenaltyExcessCharacter: 1000000
71+
PenaltyReturnTypeOnItsOwnLine: 2000000
72+
PointerAlignment: Left
73+
ReflowComments: true
74+
SortIncludes: true
75+
SpaceAfterCStyleCast: false
76+
SpaceBeforeAssignmentOperators: true
77+
SpaceBeforeParens: ControlStatements
78+
SpaceInEmptyParentheses: false
79+
SpacesBeforeTrailingComments: 1
80+
SpacesInAngles: false
81+
SpacesInContainerLiterals: true
82+
SpacesInCStyleCastParentheses: false
83+
SpacesInParentheses: false
84+
SpacesInSquareBrackets: false
85+
Standard: Cpp11
86+
TabWidth: 8
87+
UseTab: Never
88+
...

benchmarks/basic.py

+24
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
from nestedtensor import torch
2+
import utils
3+
4+
import random
5+
6+
7+
def gen_list_nested_tensor_construction():
8+
tensors = [torch.rand(random.randint(500, 1500), 25600) for _ in range(20)]
9+
def _algorithm():
10+
nt = torch._ListNestedTensor(tensors)
11+
return _algorithm
12+
13+
def gen_list_nested_tensor_unbind():
14+
nested_tensor = torch._ListNestedTensor([torch.rand(random.randint(500, 1500), 25600) for _ in range(20)])
15+
def _algorithm():
16+
ts = nested_tensor.unbind()
17+
return _algorithm
18+
19+
if __name__ == "__main__":
20+
# print(utils.benchmark_fn(alg, use_cprofile=True))
21+
# alg = gen_list_nested_tensor_construction()
22+
# print(utils.benchmark_fn(alg))
23+
alg = gen_list_nested_tensor_unbind()
24+
print(utils.benchmark_fn(alg))

benchmarks/jit_apply.py

+73
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
from nestedtensor import torch
2+
import nestedtensor
3+
import utils
4+
5+
6+
def vmap(fn):
7+
def decorator(arg):
8+
if torch.is_tensor(arg):
9+
return fn(arg)
10+
else:
11+
def asd(x):
12+
return fn(x)
13+
return arg.jit_apply(torch.jit.script(asd))
14+
return decorator
15+
16+
17+
@torch.jit.script
18+
def my_fun(x):
19+
x = x + 1
20+
y = x.abs()
21+
return y
22+
23+
# print(e)
24+
25+
26+
def gen_current():
27+
n = torch.as_nested_tensor(
28+
[torch.randn(256, 128).to(device='cuda') for _ in range(128)])
29+
30+
def _algorithm():
31+
n1 = n + 1
32+
n1.abs()
33+
34+
return _algorithm
35+
36+
37+
def gen_jit():
38+
39+
n = nestedtensor._ListNestedTensor(
40+
[torch.randn(256, 128).to(device='cuda') for _ in range(128)])
41+
42+
def gen_my_fun(scalar, tensor):
43+
@torch.jit.ignore
44+
def get_scalar() -> float:
45+
return scalar
46+
47+
@torch.jit.ignore
48+
def get_tensor() -> torch.Tensor:
49+
return tensor
50+
51+
@torch.jit.script
52+
def my_fun(x, y):
53+
x = x + get_scalar()
54+
x = x + get_tensor()
55+
y = y + x.abs()
56+
return y
57+
return my_fun
58+
my_fun = gen_my_fun(3.0, torch.randn(1).to(device='cuda'))
59+
60+
def _algorithm():
61+
nestedtensor._C.jit_apply_function((n, n), my_fun)
62+
63+
return _algorithm
64+
65+
66+
if __name__ == "__main__":
67+
# print(utils.benchmark_fn(alg, use_cprofile=True))
68+
# alg = gen_list_nested_tensor_construction()
69+
# print(utils.benchmark_fn(alg))
70+
alg1 = gen_current()
71+
print(utils.benchmark_fn(alg1))
72+
alg2 = gen_jit()
73+
print(utils.benchmark_fn(alg2))

benchmarks/nearest_neighbors.py

+165
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
from nestedtensor import torch
2+
import nestedtensor
3+
import argparse
4+
import time
5+
import random
6+
import pprint
7+
8+
EMBED_DIM = 1024
9+
10+
SEED = 0
11+
12+
13+
def gen_tensor():
14+
globals()['SEED'] += 1
15+
# return torch.tensor([globals()['SEED']])
16+
return torch.rand(EMBED_DIM).to(device='cuda')
17+
18+
19+
def gen_clusters(num_clusters, size_range):
20+
21+
def gen_cluster(num_entries):
22+
return [gen_tensor() for _ in range(num_entries)]
23+
24+
return [gen_cluster(random.randint(*size_range)) for _ in range(num_clusters)]
25+
26+
27+
def gen_algorithm_naive(keys, sub_clusters):
28+
# For-loops over vectors
29+
def _naive():
30+
results = []
31+
for sub_cluster, key in zip(sub_clusters, keys):
32+
sub_cluster_results = []
33+
for cluster in sub_cluster:
34+
sub_cluster_results.append(
35+
[torch.dot(key, entry).item() for entry in cluster])
36+
results.append(sub_cluster_results)
37+
return results
38+
return _naive
39+
40+
def gen_algorithm_mv(keys, sub_clusters):
41+
# For-loops over vectors and matrices
42+
new_sub_clusters = []
43+
for sub_cluster in sub_clusters:
44+
new_sub_cluster = [torch.stack(cluster) for cluster in sub_cluster]
45+
new_sub_clusters.append(new_sub_cluster)
46+
sub_clusters = new_sub_clusters
47+
def _mv():
48+
results = []
49+
for sub_cluster, key in zip(sub_clusters, keys):
50+
sub_cluster_results = []
51+
for cluster in sub_cluster:
52+
sub_cluster_results.append(torch.mv(cluster, key))
53+
results.append(sub_cluster_results)
54+
return results
55+
return _mv
56+
57+
def gen_algorithm_nested_mv(keys, sub_clusters):
58+
# For-loops over vectors and matrices
59+
new_sub_clusters = []
60+
for sub_cluster in sub_clusters:
61+
new_sub_cluster = [torch.tensor(list(map(list, cluster))) for cluster in sub_cluster]
62+
new_sub_clusters.append(new_sub_cluster)
63+
nested_sub_clusters = torch.nested_tensor(sub_clusters).to_tensor(2)
64+
nested_keys = torch.nested_tensor(keys)
65+
def _nested_mv():
66+
return torch.mv(nested_sub_clusters, nested_keys)
67+
return _nested_mv
68+
69+
def gen_algorithm_nested_jit_mv(keys, sub_clusters):
70+
# For-loops over vectors and matrices
71+
new_sub_clusters = []
72+
for sub_cluster in sub_clusters:
73+
new_sub_cluster = []
74+
for cluster in sub_cluster:
75+
new_sub_cluster.append(torch.stack(cluster))
76+
new_sub_clusters.append(new_sub_cluster)
77+
nested_sub_clusters = nestedtensor._ListNestedTensor(new_sub_clusters)
78+
print("HERE")
79+
print(nested_sub_clusters.nested_size())
80+
nested_keys = nestedtensor._ListNestedTensor(keys)
81+
print(nested_keys.nested_size())
82+
83+
@torch.jit.script
84+
def my_fun(x, y):
85+
return torch.mv(x, y)
86+
87+
def _nested_jit_mv():
88+
return nestedtensor._C.jit_apply_function((nested_sub_clusters, nested_keys), my_fun)
89+
return _nested_jit_mv
90+
91+
92+
def print_results(results, keys, sub_clusters, print_details=False):
93+
if print_details:
94+
for i, sub_cluster in enumerate(sub_clusters):
95+
print("\n\u001b[31msub cluster {} count {} total number of entries {}\u001b[0m".format(
96+
i, len(sub_cluster), sum(map(len, sub_cluster))))
97+
pprint.pprint(sub_cluster)
98+
print("\nkeys")
99+
pprint.pprint(keys)
100+
print("")
101+
102+
for i, result in enumerate(results):
103+
print(
104+
"result scores for \u001b[31msub cluster {} and key {}\u001b[0m".format(i, i))
105+
pprint.pprint(result)
106+
107+
def benchmark_fn(fn, run_time = 15.0):
108+
times = []
109+
num_runs = 0
110+
fn()
111+
t = 0.0
112+
while (t < run_time):
113+
ti = time.time()
114+
fn()
115+
torch.cuda.synchronize()
116+
ti = time.time() - ti
117+
t += ti
118+
times.append(ti)
119+
times = torch.tensor(times) * 1e6
120+
return "fn {:<15} avg(us): {:10.4f} std(us): {:10.4f} num_runs: {}".format(fn.__name__, times.mean().item(), times.std().item(), len(times))
121+
122+
123+
if __name__ == "__main__":
124+
parser = argparse.ArgumentParser()
125+
parser.add_argument('--print-results', dest='print_results', action='store_true')
126+
args = parser.parse_args()
127+
# NOTE: This dodging creating these subclusters from a single set of clusters
128+
# This additional memory pressure might be crucial
129+
keys = [gen_tensor()] * 16
130+
clusters = gen_clusters(16, (16,16))
131+
sub_clusters = [[clusters[random.randint(0, 15)]] * 8 for _ in range(16)]
132+
133+
# Two keys for now
134+
# Simulating some overlap
135+
136+
sub_clusters = [clusters[:3], clusters[2:]]
137+
138+
# Get algorithm
139+
gen_results_naive = gen_algorithm_naive(keys, sub_clusters)
140+
gen_results_mv = gen_algorithm_mv(keys, sub_clusters)
141+
gen_results_nested_mv = gen_algorithm_nested_mv(keys, sub_clusters)
142+
gen_results_nested_jit_mv = gen_algorithm_nested_jit_mv(keys, sub_clusters)
143+
144+
# print(benchmark_fn(gen_results_naive))
145+
# print(benchmark_fn(gen_results_mv))
146+
# print(benchmark_fn(gen_results_nested_mv))
147+
print(benchmark_fn(gen_results_nested_jit_mv))
148+
# import cProfile, pstats, io
149+
# pr = cProfile.Profile()
150+
# pr.enable()
151+
# pr.disable()
152+
# s = io.StringIO()
153+
# sortby = 'tottime'
154+
# ps = pstats.Stats(pr, stream=s).sort_stats(sortby)
155+
# ps.print_stats()
156+
# print(s.getvalue())
157+
# print(benchmark_fn(gen_results_nested_mv))
158+
159+
if args.print_results:
160+
print('naive')
161+
print_results(gen_results_naive(), keys, sub_clusters)
162+
print('\nmv')
163+
print_results(gen_results_mv(), keys, sub_clusters)
164+
print('\nnested_mv')
165+
print_results(gen_results_nested_mv(), keys, sub_clusters)

0 commit comments

Comments
 (0)