Skip to content

Commit 692996e

Browse files
committed
updated.
1 parent 89df0ca commit 692996e

File tree

5 files changed

+10
-7
lines changed

5 files changed

+10
-7
lines changed

docs/index.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22

33
**Red Coast** (redco) is a lightweight and user-friendly tool designed to automate distributed training and inference for large models while simplifying the ML pipeline development process without necessitating MLSys expertise from users.
44

5-
RedCoast supports *Large Models* + *Complex Algorithms*, in a *lightweight* and *user-friendly* manner:
5+
RedCoast supports *Large Models* + *Complex Algorithms*, in a *lightweight* and *user-friendly* way:
66

77
* Large Models beyond Transformers, e.g, [Stable Diffusion](https://github.com/tanyuqian/redco/tree/master/examples/text_to_image), etc.
8-
* Complex algorithms beyond cross entropy, e.g., [Meta Learning](https://github.com/tanyuqian/redco/tree/master/examples/meta_learning), etc.
8+
* Complex algorithms beyond cross entropy, e.g., [Meta Learning](https://github.com/tanyuqian/redco/tree/master/examples/meta_learning), [DP Training](https://github.com/tanyuqian/redco/tree/master/examples/differential_private_training), etc.
99

1010
With RedCoast, to define a ML pipeline, only three functions are needed:
1111

docs/mnist.md

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
This is a trivial MNIST example with RedCoast. Runnable by
1+
This is a trivial MNIST example with RedCoast (`pip install redco==0.4.22`). Runnable by
22
```
33
python main.py
44
```
@@ -47,14 +47,14 @@ def collate_fn(examples):
4747

4848

4949
# Loss function converting model inputs to a scalar loss
50-
def loss_fn(train_rng, state, params, batch, is_training):
50+
def loss_fn(rng, state, params, batch, is_training):
5151
logits = state.apply_fn({'params': params}, batch['images'])
5252
return optax.softmax_cross_entropy_with_integer_labels(
5353
logits=logits, labels=batch['labels']).mean()
5454

5555

5656
# Predict function converting model inputs to the model outputs
57-
def pred_fn(pred_rng, params, batch, model):
57+
def pred_fn(rng, params, batch, model):
5858
accs = model.apply({'params': params}, batch['images']).argmax(axis=-1)
5959
return {'acc': accs}
6060

redco/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
__version__ = '0.4.21'
15+
__version__ = '0.4.22'
1616

1717
from .deployers import *
1818
from .trainers import *

redco/deployers/deployer.py

+3
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,9 @@ def gen_rng(self):
276276
return new_rng
277277

278278
def gen_model_step_rng(self):
279+
"""Get a new random number generator key for distributed model step and
280+
update the random state.
281+
"""
279282
rng = self.gen_rng()
280283
if self.mesh is None:
281284
rng = jax.random.split(

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
setup(
1919
name="redco",
20-
version="0.4.21",
20+
version="0.4.22",
2121
author="Bowen Tan",
2222
packages=find_packages(),
2323
install_requires=['jax', 'flax', 'optax'],

0 commit comments

Comments
 (0)