@@ -100,25 +100,35 @@ TEST_F(AnyModuleTest, WrongNumberOfArguments) {
100
100
#endif
101
101
ASSERT_THROWS_WITH (
102
102
any.forward (),
103
- module_name + " 's forward() method expects 2 argument(s), but received 0. "
104
- " If " + module_name + " 's forward() method has default arguments, "
105
- " please make sure the forward() method is declared with a corresponding `FORWARD_HAS_DEFAULT_ARGS` macro." );
103
+ module_name +
104
+ " 's forward() method expects 2 argument(s), but received 0. "
105
+ " If " +
106
+ module_name +
107
+ " 's forward() method has default arguments, "
108
+ " please make sure the forward() method is declared with a corresponding `FORWARD_HAS_DEFAULT_ARGS` macro." );
106
109
ASSERT_THROWS_WITH (
107
110
any.forward (5 ),
108
- module_name + " 's forward() method expects 2 argument(s), but received 1. "
109
- " If " + module_name + " 's forward() method has default arguments, "
110
- " please make sure the forward() method is declared with a corresponding `FORWARD_HAS_DEFAULT_ARGS` macro." );
111
+ module_name +
112
+ " 's forward() method expects 2 argument(s), but received 1. "
113
+ " If " +
114
+ module_name +
115
+ " 's forward() method has default arguments, "
116
+ " please make sure the forward() method is declared with a corresponding `FORWARD_HAS_DEFAULT_ARGS` macro." );
111
117
ASSERT_THROWS_WITH (
112
118
any.forward (1 , 2 , 3 ),
113
- module_name + " 's forward() method expects 2 argument(s), but received 3." );
119
+ module_name +
120
+ " 's forward() method expects 2 argument(s), but received 3." );
114
121
}
115
122
116
123
struct M_default_arg_with_macro : torch::nn::Module {
117
124
double forward (int a, int b = 2 , double c = 3.0 ) {
118
125
return a + b + c;
119
126
}
127
+
120
128
protected:
121
- FORWARD_HAS_DEFAULT_ARGS ({1 , torch::nn::AnyValue (2 )}, {2 , torch::nn::AnyValue (3.0 )})
129
+ FORWARD_HAS_DEFAULT_ARGS (
130
+ {1 , torch::nn::AnyValue (2 )},
131
+ {2 , torch::nn::AnyValue (3.0 )})
122
132
};
123
133
124
134
struct M_default_arg_without_macro : torch::nn::Module {
@@ -127,7 +137,9 @@ struct M_default_arg_without_macro : torch::nn::Module {
127
137
}
128
138
};
129
139
130
- TEST_F (AnyModuleTest, PassingArgumentsToModuleWithDefaultArgumentsInForwardMethod) {
140
+ TEST_F (
141
+ AnyModuleTest,
142
+ PassingArgumentsToModuleWithDefaultArgumentsInForwardMethod) {
131
143
{
132
144
AnyModule any (M_default_arg_with_macro{});
133
145
@@ -155,22 +167,32 @@ TEST_F(AnyModuleTest, PassingArgumentsToModuleWithDefaultArgumentsInForwardMetho
155
167
156
168
ASSERT_THROWS_WITH (
157
169
any.forward (),
158
- module_name + " 's forward() method expects 3 argument(s), but received 0. "
159
- " If " + module_name + " 's forward() method has default arguments, "
160
- " please make sure the forward() method is declared with a corresponding `FORWARD_HAS_DEFAULT_ARGS` macro." );
170
+ module_name +
171
+ " 's forward() method expects 3 argument(s), but received 0. "
172
+ " If " +
173
+ module_name +
174
+ " 's forward() method has default arguments, "
175
+ " please make sure the forward() method is declared with a corresponding `FORWARD_HAS_DEFAULT_ARGS` macro." );
161
176
ASSERT_THROWS_WITH (
162
177
any.forward <double >(1 ),
163
- module_name + " 's forward() method expects 3 argument(s), but received 1. "
164
- " If " + module_name + " 's forward() method has default arguments, "
165
- " please make sure the forward() method is declared with a corresponding `FORWARD_HAS_DEFAULT_ARGS` macro." );
178
+ module_name +
179
+ " 's forward() method expects 3 argument(s), but received 1. "
180
+ " If " +
181
+ module_name +
182
+ " 's forward() method has default arguments, "
183
+ " please make sure the forward() method is declared with a corresponding `FORWARD_HAS_DEFAULT_ARGS` macro." );
166
184
ASSERT_THROWS_WITH (
167
185
any.forward <double >(1 , 3 ),
168
- module_name + " 's forward() method expects 3 argument(s), but received 2. "
169
- " If " + module_name + " 's forward() method has default arguments, "
170
- " please make sure the forward() method is declared with a corresponding `FORWARD_HAS_DEFAULT_ARGS` macro." );
186
+ module_name +
187
+ " 's forward() method expects 3 argument(s), but received 2. "
188
+ " If " +
189
+ module_name +
190
+ " 's forward() method has default arguments, "
191
+ " please make sure the forward() method is declared with a corresponding `FORWARD_HAS_DEFAULT_ARGS` macro." );
171
192
ASSERT_THROWS_WITH (
172
193
any.forward (1 , 2 , 3.0 , 4 ),
173
- module_name + " 's forward() method expects 3 argument(s), but received 4." );
194
+ module_name +
195
+ " 's forward() method expects 3 argument(s), but received 4." );
174
196
}
175
197
}
176
198
@@ -345,13 +367,14 @@ TEST_F(AnyValueTest, CorrectlyAccessesIntWhenCorrectType) {
345
367
ASSERT_NE (value.try_get <int >(), nullptr );
346
368
// const and non-const types have the same typeid(),
347
369
// but casting Holder<int> to Holder<const int> is undefined
348
- // behavior according to UBSAN: https://github.com/pytorch/pytorch/issues/26964
370
+ // behavior according to UBSAN:
371
+ // https://github.com/pytorch/pytorch/issues/26964
349
372
// ASSERT_NE(value.try_get<const int>(), nullptr);
350
373
ASSERT_EQ (value.get <int >(), 5 );
351
374
}
352
375
// This test does not work at all, because it looks like make_value
353
376
// decays const int into int.
354
- // TEST_F(AnyValueTest, CorrectlyAccessesConstIntWhenCorrectType) {
377
+ // TEST_F(AnyValueTest, CorrectlyAccessesConstIntWhenCorrectType) {
355
378
// auto value = make_value<const int>(5);
356
379
// ASSERT_NE(value.try_get<const int>(), nullptr);
357
380
// // ASSERT_NE(value.try_get<int>(), nullptr);
0 commit comments