diff --git a/firebase_admin/_user_mgt.py b/firebase_admin/_user_mgt.py index aa0dfb0a..22802979 100644 --- a/firebase_admin/_user_mgt.py +++ b/firebase_admin/_user_mgt.py @@ -25,7 +25,7 @@ from firebase_admin import _rfc3339 from firebase_admin import _user_identifier from firebase_admin import _user_import -from firebase_admin._user_import import ErrorInfo +from firebase_admin._user_import import ErrorInfo, UserProvider MAX_LIST_USERS_RESULTS = 1000 @@ -688,7 +688,8 @@ def create_user(self, uid=None, display_name=None, email=None, phone_number=None def update_user(self, uid, display_name=None, email=None, phone_number=None, photo_url=None, password=None, disabled=None, email_verified=None, - valid_since=None, custom_claims=None, providers_to_delete=None): + valid_since=None, custom_claims=None, providers_to_delete=None, + provider_to_add: UserProvider | None=None): """Updates an existing user account with the specified properties""" payload = { 'localId': _auth_utils.validate_uid(uid, required=True), @@ -727,6 +728,9 @@ def update_user(self, uid, display_name=None, email=None, phone_number=None, custom_claims, dict) else custom_claims payload['customAttributes'] = _auth_utils.validate_custom_claims(json_claims) + if provider_to_add: + payload['linkProviderUserInfo'] = provider_to_add.to_dict() + if remove_provider: payload['deleteProvider'] = list(set(remove_provider)) diff --git a/integration/test_auth.py b/integration/test_auth.py index e1d01a25..dfcab5f8 100644 --- a/integration/test_auth.py +++ b/integration/test_auth.py @@ -16,6 +16,7 @@ import base64 import datetime import random +import re import string import time from typing import List @@ -30,6 +31,7 @@ import firebase_admin from firebase_admin import auth from firebase_admin import credentials +from firebase_admin import exceptions _verify_token_url = 'https://www.googleapis.com/identitytoolkit/v3/relyingparty/verifyCustomToken' @@ -496,6 +498,32 @@ def test_disable_user(new_user_with_params): assert user.disabled is True assert len(user.provider_data) == 1 +def test_add_valid_provider(new_user_with_provider): + new_provider = auth.UserProvider(uid=new_user_with_provider.uid, provider_id='microsoft.com') + existing_provider_ids = [provider.provider_id for provider in new_user_with_provider.provider_data] + assert new_provider.provider_id not in existing_provider_ids + user = auth.update_user(new_user_with_provider.uid, provider_to_add=new_provider) + assert user.uid == new_user_with_provider.uid + new_provider_ids = [provider.provider_id for provider in user.provider_data] + assert sorted(new_provider_ids) == sorted(existing_provider_ids + [new_provider.provider_id]) + +def test_add_invalid_provider(new_user_with_provider): + new_provider = auth.UserProvider(uid=new_user_with_provider.uid, provider_id='xyz.com') + existing_provider_ids = [provider.provider_id for provider in new_user_with_provider.provider_data] + assert new_provider.provider_id not in existing_provider_ids + with pytest.raises(exceptions.InvalidArgumentError, match=re.escape( + f"Error while calling Auth service (INVALID_PROVIDER_ID ). provider {new_provider.provider_id} is not supported for linking." + )): + auth.update_user(new_user_with_provider.uid, provider_to_add=new_provider) + +def test_add_duplicate_provider(new_user_with_provider): + google_uid, google_email = _random_id() + duplicate_provider = auth.UserProvider(uid=google_uid, provider_id='google.com', email=google_email) + with pytest.raises(exceptions.InvalidArgumentError, match=re.escape( + f"Error while calling Auth service (PROVIDER_ALREADY_LINKED)." + )): + auth.update_user(new_user_with_provider.uid, provider_to_add=duplicate_provider) + def test_remove_provider(new_user_with_provider): provider_ids = [provider.provider_id for provider in new_user_with_provider.provider_data] assert 'google.com' in provider_ids