Skip to content

Commit 1ea79ba

Browse files
committed
use .sizes instead of .dims to avoid warning and future deprecation
1 parent 8debd72 commit 1ea79ba

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

graphcast/autoregressive.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def _validate_targets_and_forcings(self, targets, forcings):
112112
f'forcings, which isn\'t allowed: {overlap}')
113113

114114
def _update_inputs(self, inputs, next_frame):
115-
num_inputs = inputs.dims['time']
115+
num_inputs = inputs.sizes['time']
116116

117117
predicted_or_forced_inputs = next_frame[list(inputs.keys())]
118118

@@ -199,7 +199,7 @@ def one_step_prediction(inputs, scan_variables):
199199
return next_inputs, flat_pred
200200

201201
if self._gradient_checkpointing:
202-
scan_length = targets_template.dims['time']
202+
scan_length = targets_template.sizes['time']
203203
if scan_length <= 1:
204204
logging.warning(
205205
'Skipping gradient checkpointing for sequence length of 1')

graphcast/rollout.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def chunked_prediction_generator(
124124
if "datetime" in forcings.coords:
125125
del forcings.coords["datetime"]
126126

127-
num_target_steps = targets_template.dims["time"]
127+
num_target_steps = targets_template.sizes["time"]
128128
num_chunks, remainder = divmod(num_target_steps, num_steps_per_chunk)
129129
if remainder != 0:
130130
raise ValueError(
@@ -202,7 +202,7 @@ def _get_next_inputs(
202202
next_inputs = next_frame[next_inputs_keys]
203203

204204
# Apply concatenate next frame with inputs, crop what we don't need.
205-
num_inputs = prev_inputs.dims["time"]
205+
num_inputs = prev_inputs.sizes["time"]
206206
return (
207207
xarray.concat(
208208
[prev_inputs, next_inputs], dim="time", data_vars="different")

0 commit comments

Comments
 (0)