Skip to content

Commit d4ed594

Browse files
prm-james-hillpytorchmergebot
authored andcommitted
Fix floating point literals in IRPrinter (pytorch#142119)
Fixes pytorch#114035 This is a recreation of pytorch#140002 with approval from its author. Original description: >when v larger than 1e16, the format will be error. example: v is 1.2e17, the output is 1.2e17.f, it have two point '.' Pull Request resolved: pytorch#142119 Approved by: https://github.com/jgong5, https://github.com/malfet
1 parent 10b9c59 commit d4ed594

File tree

2 files changed

+15
-5
lines changed

2 files changed

+15
-5
lines changed

test/cpp/tensorexpr/test_ir_printer.cpp

+8
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,14 @@ TEST(IRPrinter, BasicValueTest02) {
3737
ASSERT_EQ(ss.str(), "(2.f + 3.f) - (4.f + 5.f)");
3838
}
3939

40+
TEST(IRPrinter, BasicValueTest03) {
41+
ExprHandle a(3.402823466385289e+38f);
42+
ExprHandle b(-3.402823466385289e+38f);
43+
std::stringstream ss;
44+
ss << a << ", " << b;
45+
ASSERT_EQ(ss.str(), "3.402823466385289e+38f, -3.402823466385289e+38f");
46+
}
47+
4048
TEST(IRPrinter, CastTest) {
4149
VarHandle x("x", kHalf);
4250
VarHandle y("y", kFloat);

torch/csrc/jit/tensorexpr/ir_printer.cpp

+7-5
Original file line numberDiff line numberDiff line change
@@ -191,25 +191,27 @@ void IRPrinter::visit(const CompareSelectPtr& v) {
191191
withParens(v->ret_val2());
192192
}
193193

194-
static void formatFPSuffix(std::ostream& os, double v) {
195-
os << (v == std::ceil(v) ? ".0" : "");
194+
static void formatFPSuffix(std::ostream& os, double v, bool flag) {
195+
os << (flag && v == std::ceil(v) ? ".0" : "");
196196
}
197197

198198
template <typename T>
199-
static void formatFPSuffix(std::ostream& os, T v) {
200-
os << (v == std::ceil(v) ? ".f" : "f");
199+
static void formatFPSuffix(std::ostream& os, T v, bool flag) {
200+
os << (flag && v == std::ceil(v) ? ".f" : "f");
201201
}
202202

203203
template <typename T, std::enable_if_t<std::is_floating_point_v<T>>* = nullptr>
204204
static void formatImm(std::ostream& os, T v) {
205205
const int precision = 16;
206+
const T lower_bound = static_cast<T>(-std::pow(10, precision));
207+
const T upper_bound = -lower_bound;
206208
if (std::isnan(v)) {
207209
os << "NAN";
208210
} else if (std::isinf(v)) {
209211
os << (v > 0 ? "POS_INFINITY" : "NEG_INFINITY");
210212
} else {
211213
os << std::setprecision(precision) << v;
212-
formatFPSuffix(os, v);
214+
formatFPSuffix(os, v, v > lower_bound && v < upper_bound);
213215
}
214216
}
215217

0 commit comments

Comments
 (0)