Skip to content

Commit 245c307

Browse files
committed
fix: nullable types with gen ai bump
1 parent 0094eea commit 245c307

File tree

5 files changed

+62
-57
lines changed

5 files changed

+62
-57
lines changed

src/google/adk/auth/auth_handler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,9 @@ async def exchange_auth_token(
4848
self,
4949
) -> AuthCredential:
5050
exchanger = OAuth2CredentialExchanger()
51-
return await exchanger.exchange(
51+
return (await exchanger.exchange(
5252
self.auth_config.exchanged_auth_credential, self.auth_config.auth_scheme
53-
)
53+
))[0]
5454

5555
async def parse_and_store_auth_response(self, state: State) -> None:
5656

src/google/adk/auth/credential_manager.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -214,15 +214,12 @@ async def _exchange_credential(
214214
return credential, False
215215

216216
if isinstance(exchanger, ServiceAccountCredentialExchanger):
217-
exchanged_credential = exchanger.exchange_credential(
218-
self._auth_config.auth_scheme, credential
219-
)
220-
else:
221-
exchanged_credential = await exchanger.exchange(
222-
credential, self._auth_config.auth_scheme
217+
return (
218+
exchanger.exchange_credential(self._auth_config.auth_scheme, credential),
219+
True
223220
)
224221

225-
return exchanged_credential, True
222+
return await exchanger.exchange(credential, self._auth_config.auth_scheme)
226223

227224
async def _refresh_credential(
228225
self, credential: AuthCredential

src/google/adk/auth/exchanger/base_credential_exchanger.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,15 +41,18 @@ async def exchange(
4141
self,
4242
auth_credential: AuthCredential,
4343
auth_scheme: Optional[AuthScheme] = None,
44-
) -> AuthCredential:
44+
) -> tuple[AuthCredential, bool]:
4545
"""Exchange credential if needed.
4646
4747
Args:
4848
auth_credential: The credential to exchange.
4949
auth_scheme: The authentication scheme (optional, some exchangers don't need it).
5050
5151
Returns:
52-
The exchanged credential.
52+
A tuple of (credential, exchanged) where:
53+
- credential: The exchanged credential if exchange occurred, otherwise
54+
the original credential.
55+
- exchanged: True if credential was exchanged, False otherwise.
5356
5457
Raises:
5558
CredentialExchangeError: If credential exchange fails.

src/google/adk/auth/exchanger/oauth2_credential_exchanger.py

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ async def exchange(
5151
self,
5252
auth_credential: AuthCredential,
5353
auth_scheme: Optional[AuthScheme] = None,
54-
) -> AuthCredential:
54+
) -> tuple[AuthCredential, bool]:
5555
"""Exchange OAuth2 credential from authorization response.
5656
5757
if credential exchange failed, the original credential will be returned.
@@ -61,7 +61,10 @@ async def exchange(
6161
auth_scheme: The OAuth2 authentication scheme.
6262
6363
Returns:
64-
The exchanged credential with access token.
64+
A tuple of (credential, exchanged) where:
65+
- credential: The exchanged credential with an access token if exchange occurred, otherwise
66+
the original credential.
67+
- exchanged: True if credential was exchanged, False otherwise.
6568
6669
Raises:
6770
CredentialExchangeError: If auth_scheme is missing.
@@ -79,10 +82,10 @@ async def exchange(
7982
logger.warning(
8083
"authlib is not available, skipping OAuth2 credential exchange."
8184
)
82-
return auth_credential
85+
return (auth_credential, False)
8386

8487
if auth_credential.oauth2 and auth_credential.oauth2.access_token:
85-
return auth_credential
88+
return (auth_credential, False)
8689

8790
# Determine grant type from auth_scheme
8891
grant_type = self._determine_grant_type(auth_scheme)
@@ -97,7 +100,7 @@ async def exchange(
97100
)
98101
else:
99102
logger.warning("Unsupported OAuth2 grant type: %s", grant_type)
100-
return auth_credential
103+
return auth_credential, False
101104

102105
def _determine_grant_type(
103106
self, auth_scheme: AuthScheme
@@ -129,22 +132,25 @@ async def _exchange_client_credentials(
129132
self,
130133
auth_credential: AuthCredential,
131134
auth_scheme: AuthScheme,
132-
) -> AuthCredential:
135+
) -> tuple[AuthCredential, bool]:
133136
"""Exchange client credentials for access token.
134137
135138
Args:
136139
auth_credential: The OAuth2 credential to exchange.
137140
auth_scheme: The OAuth2 authentication scheme.
138141
139142
Returns:
140-
The credential with access token.
143+
A tuple of (credential, exchanged) where:
144+
- credential: The exchanged credential with an access token if exchange occurred, otherwise
145+
the original credential.
146+
- exchanged: True if credential was exchanged, False otherwise.
141147
"""
142148
client, token_endpoint = create_oauth2_session(auth_scheme, auth_credential)
143149
if not client:
144150
logger.warning(
145151
"Could not create OAuth2 session for client credentials exchange"
146152
)
147-
return auth_credential
153+
return auth_credential, False
148154

149155
try:
150156
tokens = client.fetch_token(
@@ -155,9 +161,9 @@ async def _exchange_client_credentials(
155161
logger.debug("Successfully exchanged client credentials for access token")
156162
except Exception as e:
157163
logger.error("Failed to exchange client credentials: %s", e)
158-
return auth_credential
164+
return auth_credential, False
159165

160-
return auth_credential
166+
return auth_credential, True
161167

162168
def _normalize_auth_uri(self, auth_uri: str | None) -> str | None:
163169
# Authlib currently used a simplified token check by simply scanning hash existence,
@@ -171,22 +177,25 @@ async def _exchange_authorization_code(
171177
self,
172178
auth_credential: AuthCredential,
173179
auth_scheme: AuthScheme,
174-
) -> AuthCredential:
180+
) -> tuple[AuthCredential, bool]:
175181
"""Exchange authorization code for access token.
176182
177183
Args:
178184
auth_credential: The OAuth2 credential to exchange.
179185
auth_scheme: The OAuth2 authentication scheme.
180186
181187
Returns:
182-
The credential with access token.
188+
A tuple of (credential, exchanged) where:
189+
- credential: The exchanged credential with an access token if exchange occurred, otherwise
190+
the original credential.
191+
- exchanged: True if credential was exchanged, False otherwise.
183192
"""
184193
client, token_endpoint = create_oauth2_session(auth_scheme, auth_credential)
185194
if not client:
186195
logger.warning(
187196
"Could not create OAuth2 session for authorization code exchange"
188197
)
189-
return auth_credential
198+
return auth_credential, False
190199

191200
try:
192201
tokens = client.fetch_token(
@@ -201,6 +210,6 @@ async def _exchange_authorization_code(
201210
logger.debug("Successfully exchanged authorization code for access token")
202211
except Exception as e:
203212
logger.error("Failed to exchange authorization code: %s", e)
204-
return auth_credential
213+
return auth_credential, False
205214

206-
return auth_credential
215+
return auth_credential, True

tests/unittests/auth/exchanger/test_oauth2_credential_exchanger.py

Lines changed: 27 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333
class TestOAuth2CredentialExchanger:
3434
"""Test suite for OAuth2CredentialExchanger."""
3535

36-
@pytest.mark.asyncio
3736
async def test_exchange_with_existing_token(self):
3837
"""Test exchange method when access token already exists."""
3938
scheme = OpenIdConnectWithConfig(
@@ -55,14 +54,14 @@ async def test_exchange_with_existing_token(self):
5554
)
5655

5756
exchanger = OAuth2CredentialExchanger()
58-
result = await exchanger.exchange(credential, scheme)
57+
exchanged_credential, was_exchanged = await exchanger.exchange(credential, scheme)
5958

6059
# Should return the same credential since access token already exists
61-
assert result == credential
62-
assert result.oauth2.access_token == "existing_token"
60+
assert exchanged_credential == credential
61+
assert exchanged_credential.oauth2.access_token == "existing_token"
62+
assert not was_exchanged
6363

6464
@patch("google.adk.auth.oauth2_credential_util.OAuth2Session")
65-
@pytest.mark.asyncio
6665
async def test_exchange_success(self, mock_oauth2_session):
6766
"""Test successful token exchange."""
6867
# Setup mock
@@ -96,14 +95,14 @@ async def test_exchange_success(self, mock_oauth2_session):
9695
)
9796

9897
exchanger = OAuth2CredentialExchanger()
99-
result = await exchanger.exchange(credential, scheme)
98+
exchanged_credential, was_exchanged = await exchanger.exchange(credential, scheme)
10099

101100
# Verify token exchange was successful
102-
assert result.oauth2.access_token == "new_access_token"
103-
assert result.oauth2.refresh_token == "new_refresh_token"
101+
assert exchanged_credential.oauth2.access_token == "new_access_token"
102+
assert exchanged_credential.oauth2.refresh_token == "new_refresh_token"
103+
assert was_exchanged
104104
mock_client.fetch_token.assert_called_once()
105105

106-
@pytest.mark.asyncio
107106
async def test_exchange_missing_auth_scheme(self):
108107
"""Test exchange with missing auth_scheme raises ValueError."""
109108
credential = AuthCredential(
@@ -122,7 +121,6 @@ async def test_exchange_missing_auth_scheme(self):
122121
assert "auth_scheme is required" in str(e)
123122

124123
@patch("google.adk.auth.oauth2_credential_util.OAuth2Session")
125-
@pytest.mark.asyncio
126124
async def test_exchange_no_session(self, mock_oauth2_session):
127125
"""Test exchange when OAuth2Session cannot be created."""
128126
# Mock to return None for create_oauth2_session
@@ -146,14 +144,14 @@ async def test_exchange_no_session(self, mock_oauth2_session):
146144
)
147145

148146
exchanger = OAuth2CredentialExchanger()
149-
result = await exchanger.exchange(credential, scheme)
147+
exchanged_credential, was_exchanged = await exchanger.exchange(credential, scheme)
150148

151149
# Should return original credential when session creation fails
152-
assert result == credential
153-
assert result.oauth2.access_token is None
150+
assert exchanged_credential == credential
151+
assert exchanged_credential.oauth2.access_token is None
152+
assert not was_exchanged
154153

155154
@patch("google.adk.auth.oauth2_credential_util.OAuth2Session")
156-
@pytest.mark.asyncio
157155
async def test_exchange_fetch_token_failure(self, mock_oauth2_session):
158156
"""Test exchange when fetch_token fails."""
159157
# Setup mock to raise exception during fetch_token
@@ -181,14 +179,14 @@ async def test_exchange_fetch_token_failure(self, mock_oauth2_session):
181179
)
182180

183181
exchanger = OAuth2CredentialExchanger()
184-
result = await exchanger.exchange(credential, scheme)
182+
exchanged_credential, was_exchanged = await exchanger.exchange(credential, scheme)
185183

186184
# Should return original credential when fetch_token fails
187-
assert result == credential
188-
assert result.oauth2.access_token is None
185+
assert exchanged_credential == credential
186+
assert exchanged_credential.oauth2.access_token is None
187+
assert not was_exchanged
189188
mock_client.fetch_token.assert_called_once()
190189

191-
@pytest.mark.asyncio
192190
async def test_exchange_authlib_not_available(self):
193191
"""Test exchange when authlib is not available."""
194192
scheme = OpenIdConnectWithConfig(
@@ -217,14 +215,14 @@ async def test_exchange_authlib_not_available(self):
217215
"google.adk.auth.exchanger.oauth2_credential_exchanger.AUTHLIB_AVAILABLE",
218216
False,
219217
):
220-
result = await exchanger.exchange(credential, scheme)
218+
exchanged_credential, was_exchanged = await exchanger.exchange(credential, scheme)
221219

222220
# Should return original credential when authlib is not available
223-
assert result == credential
224-
assert result.oauth2.access_token is None
221+
assert exchanged_credential == credential
222+
assert exchanged_credential.oauth2.access_token is None
223+
assert not was_exchanged
225224

226225
@patch("google.adk.auth.oauth2_credential_util.OAuth2Session")
227-
@pytest.mark.asyncio
228226
async def test_exchange_client_credentials_success(self, mock_oauth2_session):
229227
"""Test successful client credentials exchange."""
230228
# Setup mock
@@ -255,17 +253,17 @@ async def test_exchange_client_credentials_success(self, mock_oauth2_session):
255253
)
256254

257255
exchanger = OAuth2CredentialExchanger()
258-
result = await exchanger.exchange(credential, scheme)
256+
exchanged_credential, was_exchanged = await exchanger.exchange(credential, scheme)
259257

260258
# Verify client credentials exchange was successful
261-
assert result.oauth2.access_token == "client_access_token"
259+
assert exchanged_credential.oauth2.access_token == "client_access_token"
260+
assert was_exchanged
262261
mock_client.fetch_token.assert_called_once_with(
263262
"https://example.com/token",
264263
grant_type="client_credentials",
265264
)
266265

267266
@patch("google.adk.auth.oauth2_credential_util.OAuth2Session")
268-
@pytest.mark.asyncio
269267
async def test_exchange_client_credentials_failure(self, mock_oauth2_session):
270268
"""Test client credentials exchange failure."""
271269
# Setup mock to raise exception during fetch_token
@@ -292,15 +290,15 @@ async def test_exchange_client_credentials_failure(self, mock_oauth2_session):
292290
)
293291

294292
exchanger = OAuth2CredentialExchanger()
295-
result = await exchanger.exchange(credential, scheme)
293+
exchanged_credential, was_exchanged = await exchanger.exchange(credential, scheme)
296294

297295
# Should return original credential when client credentials exchange fails
298-
assert result == credential
299-
assert result.oauth2.access_token is None
296+
assert exchanged_credential == credential
297+
assert exchanged_credential.oauth2.access_token is None
298+
assert not was_exchanged
300299
mock_client.fetch_token.assert_called_once()
301300

302301
@patch("google.adk.auth.oauth2_credential_util.OAuth2Session")
303-
@pytest.mark.asyncio
304302
async def test_exchange_normalize_uri(self, mock_oauth2_session):
305303
"""Test exchange method normalizes auth_response_uri."""
306304
mock_client = Mock()
@@ -343,7 +341,6 @@ async def test_exchange_normalize_uri(self, mock_oauth2_session):
343341
grant_type=OAuthGrantType.AUTHORIZATION_CODE,
344342
)
345343

346-
@pytest.mark.asyncio
347344
async def test_determine_grant_type_client_credentials(self):
348345
"""Test grant type determination for client credentials."""
349346
flows = OAuthFlows(
@@ -360,7 +357,6 @@ async def test_determine_grant_type_client_credentials(self):
360357

361358
assert grant_type == OAuthGrantType.CLIENT_CREDENTIALS
362359

363-
@pytest.mark.asyncio
364360
async def test_determine_grant_type_openid_connect(self):
365361
"""Test grant type determination for OpenID Connect (defaults to auth code)."""
366362
scheme = OpenIdConnectWithConfig(

0 commit comments

Comments
 (0)