Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 091373c

Browse files
afrozenatorCopybara-Service
authored and
Copybara-Service
committed
Fix Problem.feature_info.
PiperOrigin-RevId: 219247790
1 parent 66afb76 commit 091373c

File tree

1 file changed

+5
-11
lines changed

1 file changed

+5
-11
lines changed

tensor2tensor/data_generators/problem.py

+5-11
Original file line numberDiff line numberDiff line change
@@ -716,23 +716,17 @@ def feature_info(self):
716716
assert self._hparams is not None
717717

718718
hp = self.get_hparams()
719-
input_mods = hp.modality["inputs"]
720-
target_mod = hp.modality["targets"]
721-
vocabs = hp.vocabulary
722719
if self.has_inputs:
723720
in_id = hp.input_space_id
724721
out_id = hp.target_space_id
725722

726723
features = collections.defaultdict(FeatureInfo)
724+
for feature_name, modality_cls in six.iteritems(hp.modality):
725+
finfo = features[feature_name]
726+
finfo.modality = modality_cls
727+
finfo.vocab_size = modality_cls.top_dimensionality
727728

728-
for name, mod in six.iteritems(input_mods):
729-
finfo = features[name]
730-
finfo.modality = mod
731-
finfo.vocab_size = mod.top_dimensionality
732-
733-
features["targets"].modality = target_mod
734-
features["targets"].vocab_size = target_mod.top_dimensionality
735-
729+
vocabs = hp.vocabulary
736730
for name, encoder in six.iteritems(vocabs):
737731
features[name].encoder = encoder
738732

0 commit comments

Comments
 (0)