forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 1
/
code_template.h
238 lines (224 loc) · 6.56 KB
/
code_template.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
#pragma once
#include <cctype>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
namespace torch {
namespace jit {
// A template environment is a mapping from template variable names, e.g.,
// identifier (corresponding to $identifier), to their expansions.
//
// This template environment supports storing strings, numbers and lists
// of strings, and can be chained together (so that lookup proceeds in
// the top level environment, and then recurses into a parent
// environment if the key is not found).
struct TemplateEnv {
  // Builds a root environment: lookups that miss here have nowhere to fall
  // back to and raise.
  TemplateEnv() : parent_(nullptr) {}

  // Builds an environment chained to 'parent': any key that is not stored
  // locally is looked up in 'parent'. Only a pointer to 'parent' is kept, so
  // it must outlive this environment.
  TemplateEnv(TemplateEnv& parent) : parent_(&parent) {}

  using string_list = std::vector<std::string>;

  // Bind key 'k' to the string 'v', replacing any list previously bound to
  // 'k' in this environment.
  void s(const std::string& k, const std::string& v) {
    // Assign before erasing: 'v' may refer to an element of the list that is
    // about to be dropped.
    strings_[k] = v;
    lists_.erase(k);
  }

  // Bind key 'k' to the std::to_string() rendering of the number 'v'.
  template <typename T>
  void d(const std::string& k, const T& v) {
    s(k, std::to_string(v));
  }

  // Look up the string bound to 'k', searching this environment first and
  // then its ancestors. Throws std::logic_error if no environment in the
  // chain has a string for 'k'.
  const std::string& s(const std::string& k) const {
    const auto hit = strings_.find(k);
    if (hit != strings_.end()) {
      return hit->second;
    }
    if (parent_ == nullptr) {
      notFound(k);
    }
    return parent_->s(k);
  }

  // Bind key 'k' to the list of strings 'v', replacing any string previously
  // bound to 'k' in this environment.
  void v(const std::string& k, const string_list& v) {
    // Assign before erasing, for the same aliasing reason as in s().
    lists_[k] = v;
    strings_.erase(k);
  }

  // Look up the list bound to 'k', searching this environment first and
  // then its ancestors. Throws std::logic_error if no environment in the
  // chain has a list for 'k'.
  const string_list& v(const std::string& k) const {
    const auto hit = lists_.find(k);
    if (hit != lists_.end()) {
      return hit->second;
    }
    if (parent_ == nullptr) {
      notFound(k);
    }
    return parent_->v(k);
  }

  // Report whether 'k' resolves to a string (true) or to a list (false).
  // The nearest environment that knows 'k' decides. Throws std::logic_error
  // if 'k' is unknown to the whole chain.
  bool keyIsString(const std::string& k) const {
    if (strings_.find(k) != strings_.end()) {
      return true;
    }
    if (lists_.find(k) != lists_.end()) {
      return false;
    }
    if (parent_ == nullptr) {
      notFound(k);
    }
    return parent_->keyIsString(k);
  }

 private:
  // Common failure path for every lookup; never returns.
  [[noreturn]] void notFound(const std::string& k) const {
    throw std::logic_error("key not found: " + k);
  }

  std::unordered_map<std::string, std::string> strings_;
  std::unordered_map<std::string, string_list> lists_;
  // Non-owning link to the enclosing environment; nullptr for a root.
  TemplateEnv* parent_;
};
/*
# Match $identifier or ${identifier} and replace with the value in env.
# If the identifier is preceded only by whitespace on its line
# and its value is a list, then it is treated as a
# block substitution: all lines of all elements are indented to match.
# If the identifier is preceded by non-whitespace on its line and its value
# is a list, then the elements are comma separated. ${,foo} will insert a
# comma before the list if the list is not empty and ${foo,} will insert
# one after.
*/
struct CodeTemplate {
  // Wraps the template text 't'. Deliberately implicit so that a string can
  // be passed wherever a CodeTemplate is expected.
  /* implicit */ CodeTemplate(std::string t) : template_text(std::move(t)) {}

  // Expands every $key / ${key} / ${,key} / ${key,} in the template using
  // 'env' and returns the resulting text.
  //
  // - A key that is the first non-whitespace item on its line is expanded in
  //   "block" mode: a string has every embedded newline followed by the
  //   current indentation, and a list is emitted one element per line at the
  //   current indentation.
  // - A key that follows non-whitespace text on its line is expanded inline:
  //   a string is copied verbatim and a list is joined with ", " (honouring
  //   the optional leading/trailing comma markers).
  //
  // Throws std::logic_error on a malformed key (template ends inside a key,
  // or a '${' without a matching '}'); lookups of unknown keys raise whatever
  // 'env' raises.
  std::string format(const TemplateEnv& env) const {
    std::ostringstream out;
    size_t pos = 0;
    // Number of characters emitted on the current line so far; only consulted
    // while the line is still all whitespace, where it equals the indent.
    size_t indent = 0;
    bool all_whitespace = true;
    while (pos < template_text.size()) {
      const char c = template_text[pos];
      if (c == '$') {
        std::string key;
        bool comma_before = false;
        bool comma_after = false;
        const size_t new_pos = parseKey(pos, key, comma_before, comma_after);
        const bool is_string = env.keyIsString(key);
        if (all_whitespace) {
          if (is_string) {
            emitStringWithIndents(out, indent, env.s(key));
          } else {
            emitLinesIndented(out, indent, env.v(key));
          }
        } else {
          if (is_string) {
            out << env.s(key);
          } else {
            emitCommaSeparatedList(out, env.v(key), comma_before, comma_after);
          }
        }
        all_whitespace = false;
        pos = new_pos;
      } else {
        out << c;
        if (c == '\n') {
          // A new line starts: reset the indentation tracking.
          indent = 0;
          all_whitespace = true;
        } else {
          ++indent;
          if (!isSpace(c)) {
            all_whitespace = false;
          }
        }
        ++pos;
      }
    }
    return out.str();
  }

 private:
  using string_list = std::vector<std::string>;

  // <cctype> classifiers have undefined behaviour for negative arguments
  // other than EOF, so plain 'char' (possibly signed, e.g. UTF-8 bytes) must
  // be converted to unsigned char first.
  static bool isSpace(char c) {
    return std::isspace(static_cast<unsigned char>(c)) != 0;
  }

  // True for characters allowed in a key name: [A-Za-z0-9_].
  static bool isIdentChar(char c) {
    return c == '_' || std::isalnum(static_cast<unsigned char>(c)) != 0;
  }

  // Bounds-checked read of the template text used while parsing a key.
  // Throws std::logic_error if 'p' is past the end of the template.
  char charAt(size_t p) const {
    if (p >= template_text.size())
      throw std::logic_error("EOS found in key");
    return template_text[p];
  }

  // Parses the key that starts at the '$' located at 'pos'. The key name is
  // appended to 'k'; 'comma_before' / 'comma_after' report whether the
  // ${,name} / ${name,} markers were present. Returns the position of the
  // first character after the key. Throws std::logic_error if the template
  // ends inside the key or a '${' is not closed by '}'.
  size_t parseKey(
      size_t pos,
      std::string& k,
      bool& comma_before,
      bool& comma_after) const {
    comma_before = false;
    comma_after = false;
    pos++; // skip '$'
    if (charAt(pos) != '{') {
      return parseIdent(pos, k);
    }
    pos++; // skip '{'
    if (charAt(pos) == ',') {
      comma_before = true;
      pos++;
    }
    pos = parseIdent(pos, k);
    if (charAt(pos) == ',') {
      comma_after = true;
      pos++;
    }
    if (charAt(pos) != '}')
      throw std::logic_error("missing terminating '}'");
    pos++; // skip '}'
    return pos;
  }

  // Appends the longest run of identifier characters starting at 'pos' to
  // 'k' and returns the position just past it (possibly 'pos' itself).
  size_t parseIdent(size_t pos, std::string& k) const {
    while (pos < template_text.size() && isIdentChar(template_text[pos])) {
      k.push_back(template_text[pos]);
      pos++;
    }
    return pos;
  }

  // Writes 'strings' joined by ", ". A leading / trailing ", " is added when
  // requested, but only if the list is non-empty.
  void emitCommaSeparatedList(
      std::ostream& out,
      const string_list& strings,
      bool comma_before,
      bool comma_after) const {
    if (comma_before && !strings.empty())
      out << ", ";
    for (size_t i = 0; i < strings.size(); ++i) {
      if (i > 0)
        out << ", ";
      out << strings[i];
    }
    if (comma_after && !strings.empty())
      out << ", ";
  }

  // These indentation functions follow the convention that they never emit
  // leading or trailing newlines when the input string does not have leading
  // or trailing newlines. It's the responsibility of the calling function
  // to indent correctly in the context.

  // Writes 'indent' space characters.
  void emitIndent(std::ostream& out, size_t indent) const {
    for (size_t i = 0; i < indent; ++i) {
      out << " ";
    }
  }

  // Writes 'str', re-establishing the indentation after every newline it
  // contains (the first line is assumed to be indented already).
  void emitStringWithIndents(
      std::ostream& out,
      size_t indent,
      const std::string& str) const {
    for (const char c : str) {
      out << c;
      if (c == '\n') {
        emitIndent(out, indent);
      }
    }
  }

  // Writes each element of 'strings' on its own line at the given
  // indentation. The first element is assumed to be indented already and no
  // newline is written after the last one.
  void emitLinesIndented(
      std::ostream& out,
      size_t indent,
      const string_list& strings) const {
    for (size_t i = 0; i < strings.size(); ++i) {
      if (i > 0)
        emitIndent(out, indent);
      emitStringWithIndents(out, indent, strings[i]);
      if (i + 1 != strings.size())
        out << "\n";
    }
  }

  std::string template_text;
};
static inline std::string format(const std::string& fmt, TemplateEnv& env) {
return CodeTemplate(fmt).format(env);
}
} // namespace jit
} // namespace torch