Skip to content

Commit c703d33

Browse files
committed
Add support for setting the GCS credential
Set the credential for both the TPU and the local notebook. https://b.corp.google.com/issues/158133824
1 parent ad4d638 commit c703d33

File tree

2 files changed

+53
-0
lines changed

2 files changed

+53
-0
lines changed

patches/kaggle_secrets.py

+16
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import json
88
import os
99
import socket
10+
import tensorflow_gcs_config
1011
import urllib.request
1112
from datetime import datetime, timedelta
1213
from enum import Enum, unique
@@ -135,6 +136,21 @@ def get_gcloud_credential(self) -> str:
135136
else:
136137
raise
137138

139+
def set_tensorflow_credential(self, credential):
140+
"""Sets the credential for use by Tensorflow both in the local notebook
141+
and to pass to the TPU.
142+
"""
143+
# Write to a local JSON credentials file and set
144+
# GOOGLE_APPLICATION_CREDENTIALS for tensorflow running in the notebook.
145+
adc_path = os.path.join(
146+
os.environ.get('HOME', '/'), 'gcloud_credential.json')
147+
with open(adc_path, 'w') as f:
148+
f.write(credential)
149+
os.environ['GOOGLE_APPLICATION_CREDENTIALS']=adc_path
150+
151+
# set the credential for the TPU
152+
tensorflow_gcs_config.configure_gcs(credentials=credential)
153+
138154
def get_bigquery_access_token(self) -> Tuple[str, Optional[datetime]]:
139155
"""Retrieves BigQuery access token information from the UserSecrets service.
140156

tests/test_tensorflow_credentials.py

+37
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import unittest
2+
3+
import os
4+
import tensorflow_gcs_config
5+
from unittest.mock import patch
6+
from test.support import EnvironmentVarGuard
7+
from kaggle_secrets import UserSecretsClient
8+
9+
class TestTensorflowCredentials(unittest.TestCase):
10+
11+
@patch('tensorflow_gcs_config.configure_gcs')
12+
def test_set_tensorflow_credential(self, mock_configure_gcs):
13+
credential = '{"client_id":"fake_client_id",' \
14+
'"client_secret":"fake_client_secret",' \
15+
'"refresh_token":"not a refresh token",' \
16+
'"type":"authorized_user"}';
17+
18+
env = EnvironmentVarGuard()
19+
env.set('HOME', '/tmp')
20+
env.set('GOOGLE_APPLICATION_CREDENTIALS', '')
21+
22+
# These need to be set to make UserSecretsClient happy, but aren't
23+
# pertinent to this test.
24+
env.set('KAGGLE_USER_SECRETS_TOKEN', 'foobar')
25+
env.set('KAGGLE_KERNEL_INTEGRATIONS', 'AUTOML')
26+
27+
user_secrets = UserSecretsClient()
28+
user_secrets.set_tensorflow_credential(credential)
29+
30+
credential_path = '/tmp/gcloud_credential.json'
31+
self.assertEqual(
32+
credential_path, os.environ['GOOGLE_APPLICATION_CREDENTIALS'])
33+
with open(credential_path, 'r') as f:
34+
saved_cred = f.read()
35+
self.assertEqual(credential, saved_cred)
36+
37+
mock_configure_gcs.assert_called_with(credentials=credential)

0 commit comments

Comments
 (0)