Skip to content

Commit c9b76e7

Browse files
author
Douglas Blank
committed
Address Copilot review
1 parent 0463cc2 commit c9b76e7

File tree

2 files changed

+25
-35
lines changed

2 files changed

+25
-35
lines changed

comet_for_mlflow/comet_for_mlflow.py

Lines changed: 20 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -130,28 +130,13 @@ def __init__(
130130
try:
131131
self.store = _get_store(mlflow_store_uri)
132132
except RestException as e:
133-
# Check HTTP status code for authentication errors
134-
status_code = (
135-
e.get_http_status_code() if hasattr(e, "get_http_status_code") else None
136-
)
137-
error_msg = str(e)
138-
if (
139-
status_code == 401
140-
or "401" in error_msg
141-
or "Credential" in error_msg
142-
or "authentication" in error_msg.lower()
143-
):
133+
if self._is_authentication_error(e):
144134
self._log_authentication_error(
145135
mlflow_store_uri, "connecting to MLflow store"
146136
)
147137
raise
148138
except Exception as e:
149-
error_msg = str(e)
150-
if (
151-
"401" in error_msg
152-
or "Credential" in error_msg
153-
or "authentication" in error_msg.lower()
154-
):
139+
if self._is_authentication_error(e):
155140
self._log_authentication_error(
156141
mlflow_store_uri, "connecting to MLflow store"
157142
)
@@ -165,28 +150,13 @@ def __init__(
165150
try:
166151
self.mlflow_experiments = search_mlflow_store_experiments(self.store)
167152
except RestException as e:
168-
# Check HTTP status code for authentication errors
169-
status_code = (
170-
e.get_http_status_code() if hasattr(e, "get_http_status_code") else None
171-
)
172-
error_msg = str(e)
173-
if (
174-
status_code == 401
175-
or "401" in error_msg
176-
or "Credential" in error_msg
177-
or "authentication" in error_msg.lower()
178-
):
153+
if self._is_authentication_error(e):
179154
self._log_authentication_error(
180155
mlflow_store_uri, "accessing MLflow experiments"
181156
)
182157
raise
183158
except Exception as e:
184-
error_msg = str(e)
185-
if (
186-
"401" in error_msg
187-
or "Credential" in error_msg
188-
or "authentication" in error_msg.lower()
189-
):
159+
if self._is_authentication_error(e):
190160
self._log_authentication_error(
191161
mlflow_store_uri, "accessing MLflow experiments"
192162
)
@@ -722,6 +692,22 @@ def get_api_key_or_login(self, api_key):
722692

723693
return (api_key, None)
724694

695+
def _is_authentication_error(self, exception):
696+
"""Check if an exception is an authentication error (401)."""
697+
error_msg = str(exception)
698+
status_code = None
699+
700+
# Check HTTP status code for RestException
701+
if hasattr(exception, "get_http_status_code"):
702+
status_code = exception.get_http_status_code()
703+
704+
return (
705+
status_code == 401
706+
or "401" in error_msg
707+
or "Credential" in error_msg
708+
or "authentication" in error_msg.lower()
709+
)
710+
725711
def _log_authentication_error(self, mlflow_store_uri, context):
726712
"""Log helpful error message for MLflow authentication errors."""
727713
LOGGER.error("")

examples/keras-example/run.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,11 @@
1111
from tensorflow.keras.models import Sequential
1212
from tensorflow.keras.preprocessing.text import Tokenizer
1313

14-
# When using when from databricks:
14+
# The following MLflow settings are specific to Databricks.
15+
# They should only be used when running in a Databricks environment.
16+
# If running outside Databricks, you may need to set a different tracking URI
17+
# and experiment, or remove these lines entirely.
18+
1519
mlflow.set_tracking_uri("databricks")
1620
mlflow.set_experiment("/Shared/keras_example")
1721

0 commit comments

Comments
 (0)