Skip to content

Commit d397ddf

Browse files
committedFeb 2, 2024·
update flax patch
1 parent 9fa530b commit d397ddf

File tree

2 files changed

+55
-4
lines changed

2 files changed

+55
-4
lines changed
 

‎.github/container/manifest.yaml

+2-2
Original file line numberDiff line numberDiff line change
@@ -139,12 +139,12 @@ grain:
139139
latest_verified_commit: fa79b9dea81ffb00555a6c2ae2898be4bdd5e564
140140
mode: git-clone
141141
mujoco-mpc:
142-
url: https://github.com/google-deepmind/mujoco_mpc.git
142+
url: https://github.com/google-deepmind/mujoco_mpc.git
143143
tracking_ref: main
144144
latest_verified_commit: a8bdcc1d968addf0d334243ce9bd6dbb07e1d781
145145
mode: git-clone
146146
language-to-reward-2023:
147-
url: https://github.com/google-deepmind/language_to_reward_2023.git
147+
url: https://github.com/google-deepmind/language_to_reward_2023.git
148148
tracking_ref: main
149149
latest_verified_commit: abb8e5125e4ecd0da378490b73448c05a694def5
150150
mode: git-clone

‎.github/container/patches/flax/mirror-patch-t5x-layers-rebase-0131.patch

+53-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
From a563a57ca081e87b537bf724e2e29b138a14bb3f Mon Sep 17 00:00:00 2001
22
From: ashors1 <ashors@nvidia.com>
33
Date: Fri, 2 Jun 2023 15:01:21 -0700
4-
Subject: [PATCH 1/2] add t5x sharding annotations to flax layers
4+
Subject: [PATCH 1/3] add t5x sharding annotations to flax layers
55

66
---
77
flax/linen/attention.py | 34 +++++++++++++++++++++++-------
@@ -387,7 +387,7 @@ index cec8f508..e815f955 100644
387387
From 4ddde2bc878d9ff841f34e3826bbd1bce705766d Mon Sep 17 00:00:00 2001
388388
From: Terry Kong <terrycurtiskong@gmail.com>
389389
Date: Mon, 2 Oct 2023 16:10:05 -0700
390-
Subject: [PATCH 2/2] Added ConvTranspose sharding annotations (#3)
390+
Subject: [PATCH 2/3] Added ConvTranspose sharding annotations (#3)
391391

392392
Co-authored-by: sahilj <sahilj@nvidia.com>
393393
---
@@ -446,3 +446,54 @@ index 4656abf9..187ab6f5 100644
446446
--
447447
2.25.1
448448

449+
450+
From 9e68f4758d8a87553f3989e0a014974813e12390 Mon Sep 17 00:00:00 2001
451+
From: ashors1 <ashors@nvidia.com>
452+
Date: Thu, 1 Feb 2024 09:54:25 -0800
453+
Subject: [PATCH 3/3] Add missing import
454+
455+
---
456+
flax/linen/attention.py | 1 +
457+
flax/linen/linear.py | 1 +
458+
flax/linen/normalization.py | 1 +
459+
3 files changed, 3 insertions(+)
460+
461+
diff --git a/flax/linen/attention.py b/flax/linen/attention.py
462+
index 2827119f..517ff2dc 100644
463+
--- a/flax/linen/attention.py
464+
+++ b/flax/linen/attention.py
465+
@@ -39,6 +39,7 @@ from flax.typing import (
466+
Initializer,
467+
PrecisionLike,
468+
DotGeneralT,
469+
+ Tuple,
470+
)
471+
472+
def dot_product_attention_weights(
473+
diff --git a/flax/linen/linear.py b/flax/linen/linear.py
474+
index 187ab6f5..759406ed 100644
475+
--- a/flax/linen/linear.py
476+
+++ b/flax/linen/linear.py
477+
@@ -47,6 +47,7 @@ from flax.typing import (
478+
ConvGeneralDilatedT,
479+
PaddingLike,
480+
LaxPadding,
481+
+ Tuple,
482+
)
483+
484+
485+
diff --git a/flax/linen/normalization.py b/flax/linen/normalization.py
486+
index e815f955..9e59e2be 100644
487+
--- a/flax/linen/normalization.py
488+
+++ b/flax/linen/normalization.py
489+
@@ -32,6 +32,7 @@ from flax.typing import (
490+
Shape as Shape,
491+
Initializer,
492+
Axes,
493+
+ Tuple,
494+
)
495+
496+
field = dataclasses.field
497+
--
498+
2.25.1
499+

0 commit comments

Comments
 (0)
Please sign in to comment.