Skip to content

Commit

Permalink
Reduce paren explosion in IR printer. (#246)
Browse files Browse the repository at this point in the history
  • Loading branch information
resistor authored and Mikhail Zolotukhin committed Mar 3, 2020
1 parent 12ac00b commit 36e8a6f
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 40 deletions.
10 changes: 5 additions & 5 deletions test/cpp/tensorexpr/test_ir_printer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ void testIRPrinterBasicValueTest() {

std::stringstream ss;
ss << c;
EXPECT_EQ(ss.str(), "(2 + 3)");
EXPECT_EQ(ss.str(), "2 + 3");
}

void testIRPrinterBasicValueTest02() {
Expand All @@ -31,7 +31,7 @@ void testIRPrinterBasicValueTest02() {

std::stringstream ss;
ss << f;
EXPECT_EQ(ss.str(), "((2.f + 3.f) - (4.f + 5.f))");
EXPECT_EQ(ss.str(), "(2.f + 3.f) - (4.f + 5.f)");
}

void testIRPrinterLetTest01() {
Expand All @@ -43,7 +43,7 @@ void testIRPrinterLetTest01() {

std::stringstream ss;
ss << result;
EXPECT_EQ(ss.str(), "(let x = 3.f in (2.f + ((x * 3.f) + 4.f)))");
EXPECT_EQ(ss.str(), "let x = 3.f in 2.f + (x * 3.f + 4.f)");
}

void testIRPrinterLetTest02() {
Expand All @@ -58,7 +58,7 @@ void testIRPrinterLetTest02() {
std::stringstream ss;
ss << e2;
EXPECT_EQ(
ss.str(), "(let y = 6.f in (let x = 3.f in (2.f + ((x * 3.f) + (4.f * y)))))");
ss.str(), "let y = 6.f in (let x = 3.f in 2.f + (x * 3.f + 4.f * y))");
}

void testIRPrinterCastTest() {
Expand All @@ -74,7 +74,7 @@ void testIRPrinterCastTest() {
ss << e2;
EXPECT_EQ(
ss.str(),
"(let y = 6.f in (let x = int(3.f) in (2.f + ((x * 3.f) + (4.f * y)))))");
"let y = 6.f in (let x = int(3.f) in 2.f + (x * 3.f + 4.f * y))");
}
} // namespace jit
} // namespace torch
30 changes: 28 additions & 2 deletions torch/csrc/jit/tensorexpr/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,44 @@ namespace torch {
namespace jit {
namespace tensorexpr {

enum IRNodeType {
kPrimitive,
kAdd,
kSub,
kMul,
kDiv,
kMod,
kMax,
kMin,
kAnd,
kOr,
kLshift,
kRshift,
kXor,
kCompareSelect,
kLet,
kCast,
kNone
};

// The common base between all expression node.
class Expr : public KernelScopedObject {
public:
explicit Expr(Dtype dtype) : dtype_(dtype) {}
explicit Expr(Dtype dtype, IRNodeType expr_type = kNone)
: dtype_(dtype), expr_type_(expr_type) {}
Dtype dtype() const {
return dtype_;
}
TORCH_API virtual void accept(IRVisitor* visitor) const = 0;
virtual const Expr* accept_mutator(IRMutator* mutator) const = 0;

IRNodeType expr_type() const {
return expr_type_;
}

private:
Dtype dtype_;
IRNodeType expr_type_;
};

// A CRTP pattern to accept visitors for children class,
Expand Down Expand Up @@ -121,7 +147,7 @@ class Var : public ExprNode<Var> {
}

Var(const std::string& name_hint, Dtype dtype)
: ExprNodeBase(dtype), name_hint_(name_hint) {}
: ExprNodeBase(dtype, kPrimitive), name_hint_(name_hint) {}

private:
std::string name_hint_;
Expand Down
66 changes: 40 additions & 26 deletions torch/csrc/jit/tensorexpr/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,22 +10,6 @@ namespace torch {
namespace jit {
namespace tensorexpr {

enum IRNodeType {
kAdd,
kSub,
kMul,
kDiv,
kMod,
kMax,
kMin,
kAnd,
kOr,
kLshift,
kRshift,
kXor,
kCompareSelect,
};

enum CompareSelectOperation {
kEQ,
kGT,
Expand All @@ -35,6 +19,41 @@ enum CompareSelectOperation {
kNE,
};

inline int getPrecedence(IRNodeType ty) {
// Match C++ operator precedence rules, since some pretty-print expressions to C++.
// SEE: https://en.cppreference.com/w/cpp/language/operator_precedence
switch (ty) {
case kPrimitive:
return 0;
case kCast:
return 2;
case kAdd:
case kSub:
return 6;
case kMul:
case kDiv:
case kMod:
return 5;
case kMax:
case kMin:
return 99;
case kAnd:
return 11;
case kOr:
return 13;
case kLshift:
case kRshift:
return 7;
case kXor:
return 12;
case kCompareSelect:
case kLet:
return 16;
default:
return 99;
}
}

class Buffer;

class Cast : public ExprNode<Cast> {
Expand All @@ -46,7 +65,7 @@ class Cast : public ExprNode<Cast> {
return ExprHandle(new Cast(dtype, src_value.node()));
}
Cast(Dtype dtype, const Expr* src_value)
: ExprNodeBase(dtype), src_value_(src_value) {}
: ExprNodeBase(dtype, kCast), src_value_(src_value) {}

private:
const Expr* src_value_;
Expand All @@ -68,9 +87,6 @@ class BinaryOpNode : public ExprNode<Op> {
const Expr* rhs() const {
return this->rhs_;
}
IRNodeType expr_type() const {
return expr_type_;
}

static ExprHandle make(const ExprHandle& lhs, const ExprHandle& rhs) {
return ExprHandle(new Op(lhs.node(), rhs.node()));
Expand All @@ -81,10 +97,9 @@ class BinaryOpNode : public ExprNode<Op> {
const Expr* rhs_v,
IRNodeType expr_type,
ScalarType ret_type = ScalarType::None)
: ExprNode<Op>(BinaryOpDtype(lhs_v->dtype(), rhs_v->dtype(), ret_type)),
: ExprNode<Op>(BinaryOpDtype(lhs_v->dtype(), rhs_v->dtype(), ret_type), expr_type),
lhs_(CastIfNeeded(lhs_v, ExprNode<Op>::dtype())),
rhs_(CastIfNeeded(rhs_v, ExprNode<Op>::dtype())),
expr_type_(expr_type) {}
rhs_(CastIfNeeded(rhs_v, ExprNode<Op>::dtype())) { }

private:
static const Expr* CastIfNeeded(const Expr* expr, Dtype dst_dtype) {
Expand All @@ -96,7 +111,6 @@ class BinaryOpNode : public ExprNode<Op> {

const Expr* lhs_;
const Expr* rhs_;
IRNodeType expr_type_;
};

class Add : public BinaryOpNode<Add> {
Expand Down Expand Up @@ -216,7 +230,7 @@ class Min : public BinaryOpNode<Min> {
#define IMM_DECLARE(Type, Name) \
class Name##Imm : public ExprNode<Name##Imm> { \
public: \
Name##Imm(Type value) : ExprNodeBase(k##Name), value_(value) {} \
Name##Imm(Type value) : ExprNodeBase(k##Name, kPrimitive), value_(value) {} \
Type value() const { \
return value_; \
} \
Expand Down Expand Up @@ -248,7 +262,7 @@ class Let : public ExprNode<Let> {
}

Let(const Expr* var, const Expr* value, const Expr* body)
: ExprNodeBase(body->dtype()), var_(var), value_(value), body_(body) {}
: ExprNodeBase(body->dtype(), kLet), var_(var), value_(value), body_(body) {}

private:
const Expr* var_;
Expand Down
63 changes: 56 additions & 7 deletions torch/csrc/jit/tensorexpr/ir_printer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,30 @@ template <typename Op>
void visitBinaryOp(
const BinaryOpNode<Op>* v,
const std::string& op_str,
IRPrinter* printer) {
IRPrinter* printer,
bool parens = true) {
std::ostream& os = printer->os();
os << "(";
int self_prec = getPrecedence(v->expr_type());
int lhs_prec = getPrecedence(v->lhs()->expr_type());
int rhs_prec = getPrecedence(v->rhs()->expr_type());

if (lhs_prec >= self_prec) {
os << "(";
}
v->lhs()->accept(printer);
if (lhs_prec >= self_prec) {
os << ")";
}

os << " " << op_str << " ";

if (rhs_prec >= self_prec) {
os << "(";
}
v->rhs()->accept(printer);
os << ")";
if (rhs_prec >= self_prec) {
os << ")";
}
}

void IRPrinter::visit(const Add* v) {
Expand Down Expand Up @@ -95,8 +112,17 @@ void IRPrinter::visit(const Min* v) {

void IRPrinter::visit(const CompareSelect* v) {
CompareSelectOperation cmp_op = v->compare_select_op();
os() << "(";
int self_prec = getPrecedence(v->expr_type());
int lhs_prec = getPrecedence(v->lhs()->expr_type());
int rhs_prec = getPrecedence(v->rhs()->expr_type());

if (lhs_prec >= self_prec) {
os() << "(";
}
v->lhs()->accept(this);
if (lhs_prec >= self_prec) {
os() << ")";
}
switch (cmp_op) {
case CompareSelectOperation::kEQ:
os() << "==";
Expand All @@ -119,8 +145,14 @@ void IRPrinter::visit(const CompareSelect* v) {
default:
throw std::runtime_error("invalid compare select operator");
}

if (rhs_prec >= self_prec) {
os() << "(";
}
v->rhs()->accept(this);
os() << ")";
if (rhs_prec >= self_prec) {
os() << ")";
}
}

static void formatFPSuffix(std::ostream& os, double v) {
Expand Down Expand Up @@ -170,13 +202,30 @@ void IRPrinter::visit(const Var* v) {
}

void IRPrinter::visit(const Let* v) {
os() << "(let ";
int self_prec = getPrecedence(v->expr_type());
int value_prec = getPrecedence(v->value()->expr_type());
int body_prec = getPrecedence(v->body()->expr_type());
os() << "let ";
v->var()->accept(this);
os() << " = ";

if (value_prec >= self_prec) {
os() << "(";
}
v->value()->accept(this);
if (value_prec >= self_prec) {
os() << ")";
}

os() << " in ";

if(body_prec >= self_prec) {
os() << "(";
}
v->body()->accept(this);
os() << ")";
if (body_prec >= self_prec) {
os() << ")";
}
}

void IRPrinter::visit(const LetStmt* v) {
Expand Down

0 comments on commit 36e8a6f

Please sign in to comment.