diff --git a/.gitignore b/.gitignore index 6d83cfc..0a96db7 100755 --- a/.gitignore +++ b/.gitignore @@ -32,6 +32,9 @@ MANIFEST .idea pydgraph.iml +# VS Code +.vscode + # Python Virtual Environments venv .venv diff --git a/README.md b/README.md index 708025f..6f324e5 100644 --- a/README.md +++ b/README.md @@ -348,6 +348,55 @@ request = txn.create_request(mutations=[mutation], commit_now=True) txn.do_request(request) ``` +### Committing a Transaction + +A transaction can be committed using the `Txn#commit()` method. If your transaction +consist solely of `Txn#query` or `Txn#queryWithVars` calls, and no calls to +`Txn#mutate`, then calling `Txn#commit()` is not necessary. + +An error is raised if another transaction(s) modify the same data concurrently that was +modified in the current transaction. It is up to the user to retry transactions +when they fail. + +```python +txn = client.txn() +try: + # ... + # Perform any number of queries and mutations + # ... + # and finally... + txn.commit() +except pydgraph.AbortedError: + # Retry or handle exception. +finally: + # Clean up. Calling this after txn.commit() is a no-op + # and hence safe. + txn.discard() +``` + +#### Using Transaction with Context Manager + +The Python context manager will automatically perform the "`commit`" action +after all queries and mutations have been done, and perform "`discard`" action +to clean the transaction. +When something goes wrong in the scope of context manager, "`commit`" will not +be called,and the "`discard`" action will be called to drop any potential changes. + +```python +with client.begin(read_only=False, best_effort=False) as txn: + # Do some queries or mutations here +``` + +or you can directly create a transaction from the `Txn` class. + +```python +with pydgraph.Txn(client, read_only=False, best_effort=False) as txn: + # Do some queries or mutations here +``` + +> `client.begin()` can only be used with "`with-as`" blocks, while `pydgraph.Txn` class can be directly called to instantiate a transaction object. + + ### Running a Query You can run a query by calling `Txn#query(string)`. You will need to pass in a @@ -506,6 +555,28 @@ stub1.close() stub2.close() ``` +#### Use context manager to automatically clean resources + +Use function call: + +```python +with pydgraph.client_stub(SERVER_ADDR) as stub1: + with pydgraph.client_stub(SERVER_ADDR) as stub2: + client = pydgraph.DgraphClient(stub1, stub2) +``` + +Use class constructor: + +```python +with pydgraph.DgraphClientStub(SERVER_ADDR) as stub1: + with pydgraph.DgraphClientStub(SERVER_ADDR) as stub2: + client = pydgraph.DgraphClient(stub1, stub2) +``` + +Note: `client` should be used inside the "`with-as`" block. The resources related to +`client` will be automatically released outside the block and `client` is not usable +any more. + ### Setting Metadata Headers Metadata headers such as authentication tokens can be set through the metadata of gRPC methods. diff --git a/pydgraph/client.py b/pydgraph/client.py index 47bcf27..f43f4f6 100755 --- a/pydgraph/client.py +++ b/pydgraph/client.py @@ -3,6 +3,7 @@ """Dgraph python client.""" +import contextlib import random import urllib.parse @@ -154,9 +155,9 @@ def handle_alter_future(future): except Exception as error: DgraphClient._common_except_alter(error) - def txn(self, read_only=False, best_effort=False): + def txn(self, read_only=False, best_effort=False, **commit_kwargs): """Creates a transaction.""" - return txn.Txn(self, read_only=read_only, best_effort=best_effort) + return txn.Txn(self, read_only=read_only, best_effort=best_effort, **commit_kwargs) def any_client(self): """Returns a random gRPC client so that requests are distributed evenly among them.""" @@ -173,6 +174,26 @@ def close(self): for client in self._clients: client.close() + @contextlib.contextmanager + def begin(self, + read_only:bool=False, best_effort:bool=False, + timeout = None, metadata = None, credentials = None): + '''Start a managed transaction. + + Note + ---- + Only use this function in ``with-as`` blocks. + ''' + tx = self.txn(read_only=read_only, best_effort=best_effort) + try: + yield tx + if read_only == False and tx._finished == False: + tx.commit(timeout=timeout, metadata=metadata, credentials=credentials) + except Exception as e: + raise e + finally: + tx.discard() + def open(connection_string: str) -> DgraphClient: """Open a new Dgraph client. Use client.close() to close the client. diff --git a/pydgraph/client_stub.py b/pydgraph/client_stub.py index da432db..4137adb 100644 --- a/pydgraph/client_stub.py +++ b/pydgraph/client_stub.py @@ -3,6 +3,7 @@ """Stub for RPC request.""" +import contextlib import grpc from pydgraph.meta import VERSION @@ -29,6 +30,14 @@ def __init__(self, addr="localhost:9080", credentials=None, options=None): self.channel = grpc.secure_channel(addr, credentials, options) self.stub = api_grpc.DgraphStub(self.channel) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + if exc_type is not None: + raise exc_val def login(self, login_req, timeout=None, metadata=None, credentials=None): return self.stub.Login( @@ -118,3 +127,27 @@ def from_cloud(cloud_endpoint, api_key, options=None): options=options, ) return client_stub + +@contextlib.contextmanager +def client_stub(addr='localhost:9080', **kwargs): + """ Create a managed DgraphClientStub instance. + + Parameters + ---------- + addr : str, optional + credentials : ChannelCredentials, optional + options: List[Dict] + An optional list of key-value pairs (``channel_arguments`` + in gRPC Core runtime) to configure the channel. + + Note + ---- + Only use this function in ``with-as`` blocks. + """ + stub = DgraphClientStub(addr=addr, **kwargs) + try: + yield stub + except Exception as e: + raise e + finally: + stub.close() \ No newline at end of file diff --git a/pydgraph/txn.py b/pydgraph/txn.py index c5186ba..aca0912 100644 --- a/pydgraph/txn.py +++ b/pydgraph/txn.py @@ -30,7 +30,8 @@ class Txn(object): after calling commit. """ - def __init__(self, client, read_only=False, best_effort=False): + def __init__(self, client, read_only=False, best_effort=False, + timeout=None, metadata=None, credentials=None): if not read_only and best_effort: raise Exception( "Best effort transactions are only compatible with " @@ -45,6 +46,23 @@ def __init__(self, client, read_only=False, best_effort=False): self._mutated = False self._read_only = read_only self._best_effort = best_effort + self._commit_kwargs = { + "timeout": timeout, + "metadata": metadata, + "credentials": credentials + } + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if exc_type is not None: + self.discard(**self._commit_kwargs) + raise exc_val + if self._read_only == False and self._finished == False: + self.commit(**self._commit_kwargs) + else: + self.discard(**self._commit_kwargs) def query( self, @@ -201,7 +219,7 @@ def handle_query_future(future): try: response = future.result() except Exception as error: - txn._common_except_mutate(error) + Txn._common_except_mutate(error) return response @@ -212,11 +230,11 @@ def handle_mutate_future(txn, future, commit_now): response = future.result() except Exception as error: try: - txn.discard(timeout=timeout, metadata=metadata, credentials=credentials) + txn.discard(**txn._commit_kwargs) except: # Ignore error - user should see the original error. pass - txn._common_except_mutate(error) + Txn._common_except_mutate(error) if commit_now: txn._finished = True diff --git a/tests/test_acct_upsert.py b/tests/test_acct_upsert.py index 49fca46..1212a5c 100644 --- a/tests/test_acct_upsert.py +++ b/tests/test_acct_upsert.py @@ -15,7 +15,7 @@ import pydgraph -from . import helper +from tests import helper CONCURRENCY = 5 FIRSTS = ["Paul", "Eric", "Jack", "John", "Martin"] diff --git a/tests/test_acl.py b/tests/test_acl.py index 7c4d7f6..ffd8f32 100644 --- a/tests/test_acl.py +++ b/tests/test_acl.py @@ -13,6 +13,7 @@ import unittest from . import helper +import pydgraph @unittest.skipIf(shutil.which("dgraph") is None, "Dgraph binary not found.") diff --git a/tests/test_async.py b/tests/test_async.py index 645a76a..027c1e4 100644 --- a/tests/test_async.py +++ b/tests/test_async.py @@ -10,7 +10,7 @@ import pydgraph -from . import helper +from tests import helper class TestAsync(helper.ClientIntegrationTestCase): diff --git a/tests/test_client_stub.py b/tests/test_client_stub.py index 94abe5f..a3d3f91 100644 --- a/tests/test_client_stub.py +++ b/tests/test_client_stub.py @@ -70,10 +70,130 @@ def test_from_cloud(self): raise (e) +class TestDgraphClientStubContextManager(helper.ClientIntegrationTestCase): + def setUp(self): + super(TestDgraphClientStubContextManager, self).setUp() + + def test_context_manager(self): + """Test basic context manager usage for DgraphClientStub.""" + with pydgraph.DgraphClientStub(addr=self.TEST_SERVER_ADDR) as client_stub: + ver = client_stub.check_version(pydgraph.Check()) + self.assertIsNotNone(ver) + + def test_context_manager_code_exception(self): + """Test that exceptions within context manager are properly handled.""" + with self.assertRaises(AttributeError): + with pydgraph.DgraphClientStub(addr=self.TEST_SERVER_ADDR) as client_stub: + self.check_version(client_stub) # AttributeError: no such method + + def test_context_manager_function_wrapper(self): + """Test the client_stub() function wrapper for context manager.""" + with pydgraph.client_stub(addr=self.TEST_SERVER_ADDR) as client_stub: + ver = client_stub.check_version(pydgraph.Check()) + self.assertIsNotNone(ver) + + def test_context_manager_closes_stub(self): + """Test that the stub is properly closed after exiting context manager.""" + stub = None + with pydgraph.DgraphClientStub(addr=self.TEST_SERVER_ADDR) as client_stub: + stub = client_stub + ver = client_stub.check_version(pydgraph.Check()) + self.assertIsNotNone(ver) + + # After exiting context, stub should be closed and unusable + with self.assertRaises(Exception): + stub.check_version(pydgraph.Check()) + + def test_context_manager_with_client(self): + """Test using DgraphClientStub context manager with DgraphClient.""" + with pydgraph.DgraphClientStub(addr=self.TEST_SERVER_ADDR) as client_stub: + client = pydgraph.DgraphClient(client_stub) + + # Perform a simple operation + txn = client.txn(read_only=True) + query = "{ me(func: has(name)) { name } }" + resp = txn.query(query) + self.assertIsNotNone(resp) + + def test_context_manager_exception_still_closes(self): + """Test that stub is closed even when an exception occurs.""" + stub_ref = None + try: + with pydgraph.DgraphClientStub(addr=self.TEST_SERVER_ADDR) as client_stub: + stub_ref = client_stub + client_stub.check_version(pydgraph.Check()) + raise ValueError("Test exception") + except ValueError: + pass + + # Stub should still be closed despite the exception + with self.assertRaises(Exception): + stub_ref.check_version(pydgraph.Check()) + + def test_context_manager_function_wrapper_closes(self): + """Test that client_stub() function wrapper properly closes the stub.""" + stub_ref = None + with pydgraph.client_stub(addr=self.TEST_SERVER_ADDR) as client_stub: + stub_ref = client_stub + ver = client_stub.check_version(pydgraph.Check()) + self.assertIsNotNone(ver) + + # Stub should be closed after exiting + with self.assertRaises(Exception): + stub_ref.check_version(pydgraph.Check()) + + def test_context_manager_multiple_operations(self): + """Test performing multiple operations within context manager.""" + with pydgraph.DgraphClientStub(addr=self.TEST_SERVER_ADDR) as client_stub: + # Check version multiple times + ver1 = client_stub.check_version(pydgraph.Check()) + ver2 = client_stub.check_version(pydgraph.Check()) + self.assertIsNotNone(ver1) + self.assertIsNotNone(ver2) + + # Create client and perform operations + client = pydgraph.DgraphClient(client_stub) + txn = client.txn(read_only=True) + query = "{ me(func: has(name)) { name } }" + resp = txn.query(query) + self.assertIsNotNone(resp) + + def test_context_manager_nested_with_client_operations(self): + """Test full workflow: stub context manager with client and transaction operations.""" + with pydgraph.DgraphClientStub(addr=self.TEST_SERVER_ADDR) as stub: + client = pydgraph.DgraphClient(stub) + + # Set schema + schema = "test_name: string @index(fulltext) ." + op = pydgraph.Operation(schema=schema) + client.alter(op) + + # Perform mutation and query + with client.txn() as txn: + response = txn.mutate(set_obj={"test_name": "ContextManagerTest"}) + self.assertEqual(1, len(response.uids)) + uid = list(response.uids.values())[0] + + # Verify data was committed + query = '''{{ + me(func: uid("{uid}")) {{ + test_name + }} + }}'''.format(uid=uid) + + with client.txn(read_only=True) as txn: + resp = txn.query(query) + import json + results = json.loads(resp.json).get("me") + self.assertEqual([{"test_name": "ContextManagerTest"}], results) + + + def suite(): """Returns a test suite object.""" suite_obj = unittest.TestSuite() suite_obj.addTest(TestDgraphClientStub()) + suite_obj.addTest(TestDgraphClientStubContextManager()) return suite_obj diff --git a/tests/test_essentials.py b/tests/test_essentials.py index 2ee5f2a..6f06b5c 100644 --- a/tests/test_essentials.py +++ b/tests/test_essentials.py @@ -10,7 +10,7 @@ import logging import unittest -from . import helper +from tests import helper class TestEssentials(helper.ClientIntegrationTestCase): diff --git a/tests/test_queries.py b/tests/test_queries.py index d8684c9..80781ce 100755 --- a/tests/test_queries.py +++ b/tests/test_queries.py @@ -13,7 +13,7 @@ import pydgraph -from . import helper +from tests import helper class TestQueries(helper.ClientIntegrationTestCase): diff --git a/tests/test_txn.py b/tests/test_txn.py index bff027b..8c2b389 100644 --- a/tests/test_txn.py +++ b/tests/test_txn.py @@ -11,7 +11,7 @@ import pydgraph -from . import helper +from tests import helper class TestTxn(helper.ClientIntegrationTestCase): @@ -610,10 +610,193 @@ def test_sp_star2(self): self.assertEqual([{"uid": uid1}], json.loads(resp.json).get("me")) +class TestContextManager(helper.ClientIntegrationTestCase): + def setUp(self): + super(TestContextManager, self).setUp() + helper.drop_all(self.client) + helper.set_schema(self.client, "name: string @index(fulltext) .") + + def test_context_manager_by_contextlib(self): + """Test context manager via client.begin() for read-only transactions.""" + q = ''' + { + company(func: type(x.Company), first: 10){ + expand(_all_) + } + } + ''' + with self.client.begin(read_only=True, best_effort=True) as tx: + response = tx.query(q) + self.assertIsNotNone(response) + data = json.loads(response.json) + + def test_context_manager_by_class(self): + """Test context manager using Txn class directly for read-only transactions.""" + q = ''' + { + company(func: type(x.Company), first: 10){ + expand(_all_) + } + } + ''' + with pydgraph.Txn(self.client, read_only=True, best_effort=True) as tx: + response = tx.query(q) + self.assertIsNotNone(response) + data = json.loads(response.json) + + def test_context_manager_auto_commit(self): + """Test that write transactions automatically commit on successful completion.""" + with self.client.txn() as txn: + response = txn.mutate(set_obj={"name": "Alice"}) + self.assertEqual(1, len(response.uids), "Nothing was assigned") + uid = list(response.uids.values())[0] + + # Verify the data was committed by querying in a new transaction + query = '''{{ + me(func: uid("{uid}")) {{ + name + }} + }}'''.format(uid=uid) + + resp = self.client.txn(read_only=True).query(query) + self.assertEqual([{"name": "Alice"}], json.loads(resp.json).get("me")) + + def test_context_manager_read_only_auto_discard(self): + """Test that read-only transactions automatically discard.""" + # Create some data first + txn = self.client.txn() + response = txn.mutate(set_obj={"name": "Bob"}) + uid = list(response.uids.values())[0] + txn.commit() + + # Read-only transaction should auto-discard (not commit) + query = '''{{ + me(func: uid("{uid}")) {{ + name + }} + }}'''.format(uid=uid) + + with self.client.txn(read_only=True) as txn: + resp = txn.query(query) + self.assertEqual([{"name": "Bob"}], json.loads(resp.json).get("me")) + + # Transaction should be finished after context manager exits + self.assertTrue(txn._finished) + + def test_context_manager_exception_handling(self): + """Test that exceptions cause automatic discard and are re-raised.""" + with self.assertRaises(ValueError): + with self.client.txn() as txn: + response = txn.mutate(set_obj={"name": "Charlie"}) + uid = list(response.uids.values())[0] + raise ValueError("Test exception") + + # Verify transaction was discarded - data should not exist + query = '''{{ + me(func: has(name)) {{ + name + }} + }}''' + + resp = self.client.txn(read_only=True).query(query) + results = json.loads(resp.json).get("me") + # Should be empty or not contain Charlie + if results: + names = [r.get("name") for r in results] + self.assertNotIn("Charlie", names) + + def test_context_manager_transaction_finished_after_exit(self): + """Test that transaction is marked as finished after exiting context manager.""" + with self.client.txn() as txn: + txn.mutate(set_obj={"name": "David"}) + self.assertFalse(txn._finished) + + # Should be finished after exit + self.assertTrue(txn._finished) + + # Should not be able to use transaction after context manager + with self.assertRaises(Exception): + txn.query("{ me() {} }") + + def test_context_manager_multiple_mutations(self): + """Test multiple mutations within a single context manager.""" + with self.client.txn() as txn: + response1 = txn.mutate(set_obj={"name": "Eve"}) + uid1 = list(response1.uids.values())[0] + + response2 = txn.mutate(set_obj={"name": "Frank"}) + uid2 = list(response2.uids.values())[0] + + # Verify both mutations were committed + query = '''{{ + me(func: has(name), orderasc: name) {{ + name + }} + }}''' + + resp = self.client.txn(read_only=True).query(query) + results = json.loads(resp.json).get("me") + names = [r.get("name") for r in results] + self.assertIn("Eve", names) + self.assertIn("Frank", names) + + def test_context_manager_query_and_mutate(self): + """Test both query and mutate operations within a context manager.""" + # Create initial data + txn = self.client.txn() + response = txn.mutate(set_obj={"name": "Grace"}) + uid = list(response.uids.values())[0] + txn.commit() + + # Query and update in context manager + with self.client.txn() as txn: + query = '''{{ + me(func: uid("{uid}")) {{ + name + }} + }}'''.format(uid=uid) + + resp = txn.query(query) + self.assertEqual([{"name": "Grace"}], json.loads(resp.json).get("me")) + + # Update the name + txn.mutate(set_obj={"uid": uid, "name": "Grace Updated"}) + + # Verify the update was committed + resp = self.client.txn(read_only=True).query(query) + self.assertEqual([{"name": "Grace Updated"}], json.loads(resp.json).get("me")) + + def test_context_manager_invalid_nquad_exception(self): + """Test that invalid operations cause proper exception handling and discard.""" + with self.assertRaises(Exception): + with self.client.txn() as txn: + # This should fail with invalid N-Quad syntax + txn.mutate(set_nquads="_:node InvalidWithoutQuotes") + + # Transaction should be finished + self.assertTrue(txn._finished) + + def test_context_manager_read_only_cannot_mutate(self): + """Test that read-only transactions cannot mutate within context manager.""" + with self.assertRaises(Exception): + with self.client.txn(read_only=True) as txn: + txn.mutate(set_obj={"name": "Should Fail"}) + + def test_context_manager_no_mutations_auto_commit(self): + """Test that transactions with no mutations don't error on auto-commit.""" + with self.client.txn() as txn: + query = "{ me(func: has(name)) { name } }" + resp = txn.query(query) + + # Should complete without errors even though no mutations were made + self.assertTrue(txn._finished) + + def suite(): s = unittest.TestSuite() s.addTest(TestTxn()) s.addTest(TestSPStar()) + s.addTest(TestContextManager()) return s diff --git a/tests/test_upsert_block.py b/tests/test_upsert_block.py index 2629462..3548317 100644 --- a/tests/test_upsert_block.py +++ b/tests/test_upsert_block.py @@ -9,7 +9,7 @@ import logging import unittest -from . import helper +from tests import helper class TestUpsertBlock(helper.ClientIntegrationTestCase):