diff --git a/examples/snowflake_native_app_example.py b/examples/snowflake_native_app_example.py new file mode 100644 index 00000000..82781ab5 --- /dev/null +++ b/examples/snowflake_native_app_example.py @@ -0,0 +1,47 @@ +import snowflake.connector # type: ignore + +from contextual import ContextualAI + +SF_BASE_URL = 'xxxxx-xxxxx-xxxxx.snowflakecomputing.app' +BASE_URL = f'https://{SF_BASE_URL}/v1' + +SAMPLE_MESSAGE = 'Can you tell me about XYZ' + +ctx = snowflake.connector.connect( # type: ignore + user="",# snowflake account user + password='', # snowflake account password + account="organization-account", # snowflake organization and account - + session_parameters={ + 'PYTHON_CONNECTOR_QUERY_RESULT_FORMAT': 'json' + }) + +# Obtain a session token. +token_data = ctx._rest._token_request('ISSUE') # type: ignore +token_extract = token_data['data']['sessionToken'] # type: ignore + +# Create a request to the ingress endpoint with authz. +api_key = f'\"{token_extract}\"' + +client = ContextualAI(api_key=api_key, base_url=BASE_URL) + +agents = [a for a in client.agents.list() ] + +agent = agents[0] if agents else None + +if agent is None: + print('No agents found') + exit() +print(f"Found agent {agent.name} with id {agent.id}") + +messages = [ + { + 'content': SAMPLE_MESSAGE, + 'role': 'user', + } +] + +res = client.agents.query.create(agent.id, messages=messages) # type: ignore + +output = res.message.content # type: ignore + +print(output) \ No newline at end of file diff --git a/src/contextual/_client.py b/src/contextual/_client.py index 56cd7c6b..0aa4bcb5 100644 --- a/src/contextual/_client.py +++ b/src/contextual/_client.py @@ -56,6 +56,7 @@ class ContextualAI(SyncAPIClient): # client options api_key: str + is_snowflake: bool def __init__( self, @@ -97,6 +98,11 @@ def __init__( if base_url is None: base_url = f"https://api.contextual.ai/v1" + if 'snowflakecomputing.app' in str(base_url): + self.is_snowflake = True + else: + self.is_snowflake = False + super().__init__( version=__version__, base_url=base_url, @@ -123,7 +129,10 @@ def qs(self) -> Querystring: @override def auth_headers(self) -> dict[str, str]: api_key = self.api_key - return {"Authorization": f"Bearer {api_key}"} + if self.is_snowflake: + return {"Authorization": f"Snowflake Token={api_key}"} + else: + return {"Authorization": f"Bearer {api_key}"} @property @override @@ -228,6 +237,7 @@ class AsyncContextualAI(AsyncAPIClient): # client options api_key: str + is_snowflake: bool def __init__( self, @@ -269,6 +279,11 @@ def __init__( if base_url is None: base_url = f"https://api.contextual.ai/v1" + if 'snowflakecomputing.app' in str(base_url): + self.is_snowflake = True + else: + self.is_snowflake = False + super().__init__( version=__version__, base_url=base_url, @@ -295,7 +310,10 @@ def qs(self) -> Querystring: @override def auth_headers(self) -> dict[str, str]: api_key = self.api_key - return {"Authorization": f"Bearer {api_key}"} + if self.is_snowflake: + return {"Authorization": f"Snowflake Token={api_key}"} + else: + return {"Authorization": f"Bearer {api_key}"} @property @override