diff --git a/torch/csrc/jit/frontend/schema_type_parser.cpp b/torch/csrc/jit/frontend/schema_type_parser.cpp index 309395b929c359..f702286a38994c 100644 --- a/torch/csrc/jit/frontend/schema_type_parser.cpp +++ b/torch/csrc/jit/frontend/schema_type_parser.cpp @@ -175,7 +175,14 @@ c10::optional SchemaTypeParser::tryToParseDeviceType() { const std::string& num = L.expect(TK_NUMBER).text(); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) std::string::size_type num_len; - device_idx = c10::stoi(num, &num_len); + try { + device_idx = c10::stoi(num, &num_len); + } catch (const std::invalid_argument& e) { + throw ErrorReport(L.cur()) + << "Device index cannot be converted to integer"; + } catch (const std::out_of_range& e) { + throw ErrorReport(L.cur()) << "Device index is too long"; + } } if (dev == "cuda") { return c10::Device(at::kCUDA, device_idx); @@ -192,7 +199,15 @@ c10::optional SchemaTypeParser::tryToParseRequiresGrad() { const std::string& num = L.expect(TK_NUMBER).text(); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) std::string::size_type num_len; - return (bool)c10::stoi(num, &num_len); + + try { + return (bool)c10::stoi(num, &num_len); + } catch (const std::invalid_argument& e) { + throw ErrorReport(L.cur()) + << "Field requires_grad cannot be converted to integer"; + } catch (const std::out_of_range& e) { + throw ErrorReport(L.cur()) << "Field requires_grad is too long"; + } } TypePtr SchemaTypeParser::parseRefinedTensor() { @@ -245,8 +260,15 @@ TypePtr SchemaTypeParser::parseRefinedTensor() { const std::string& num = L.expect(TK_NUMBER).text(); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) std::string::size_type num_len; - auto stride = c10::stoll(num, &num_len); - strides.push_back(stride); + try { + auto stride = c10::stoll(num, &num_len); + strides.push_back(stride); + } catch (const std::invalid_argument& e) { + throw ErrorReport(L.cur()) + << "The stride value cannot be converted to int"; + } catch (const std::out_of_range& e) { + throw ErrorReport(L.cur()) << "The stride is too big"; + } }); return; } @@ -277,7 +299,14 @@ TypePtr SchemaTypeParser::parseRefinedTensor() { const std::string& num = L.expect(TK_NUMBER).text(); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) std::string::size_type num_len; - int64_t dim = c10::stoll(num, &num_len); + int64_t dim = 0; + try { + dim = c10::stoll(num, &num_len); + } catch (const std::invalid_argument& e) { + throw ErrorReport(L.cur()) << "The number can't be converted to int"; + } catch (const std::out_of_range& e) { + throw ErrorReport(L.cur()) << "Number is too big"; + } if (shape_symbol) { L.expect(')'); dim = -dim; diff --git a/torch/csrc/jit/ir/irparser.cpp b/torch/csrc/jit/ir/irparser.cpp index 8a132a29fd9b7e..25c04a00e7ff28 100644 --- a/torch/csrc/jit/ir/irparser.cpp +++ b/torch/csrc/jit/ir/irparser.cpp @@ -189,16 +189,40 @@ ParsedLiteral IRParser::parseScalarLiteral(Node* n) { str += L.cur().text(); if (str.find('j') != std::string::npos) { r.k = AttributeKind::c; - auto imag = c10::stod(str.substr(0, str.size() - 1)); + double imag = 0.0f; + try { + imag = c10::stod(str.substr(0, str.size() - 1)); + } catch (const std::invalid_argument& e) { + throw ErrorReport(token.range) + << "Number cannot be converted to double"; + } catch (const std::out_of_range& e) { + throw ErrorReport(token.range) + << "Number is too long to be represented in type double"; + } r.c = c10::complex(0, imag); } else if ( str.find('.') != std::string::npos || str.find('e') != std::string::npos) { r.k = AttributeKind::f; - r.f = c10::stod(str); + try { + r.f = c10::stod(str); + } catch (const std::invalid_argument& e) { + throw ErrorReport(token.range) + << "Number cannot be converted to double"; + } catch (const std::out_of_range& e) { + throw ErrorReport(token.range) + << "Number is too long to be represented in type double"; + } } else { r.k = AttributeKind::i; - r.i = c10::stoll(str); + try { + r.i = c10::stoll(str); + } catch (const std::invalid_argument& e) { + throw ErrorReport(token.range) + << "Number cannot be converted to integer"; + } catch (const std::out_of_range& e) { + throw ErrorReport(token.range) << "Number is too big"; + } } L.next(); return r;