Skip to content

Commit c6e8dbe

Browse files
gaogaotiantian authored
and zhengruifeng committed
[SPARK-54247][PYTHON][FOLLOWUP] Ensure socket is not leaked and add regression tests
### What changes were proposed in this pull request? * #53200 did not fix the socket leak issue properly. This PR fixes it by closing the socket as well * In order to prevent such mistakes, regression tests are added * `_local_iterator_from_socket` is leaking the socket too; this is fixed in this PR as well. ### Why are the changes needed? To avoid leaking the socket (technically it is still closed at GC, as explained in #53200). ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? A new regression test is added, and I confirmed the test fails before the fix and passes after. ### Was this patch authored or co-authored using generative AI tooling? No Closes #53203 from gaogaotiantian/really-fix-socket-leak. Authored-by: Tian Gao <[email protected]> Signed-off-by: Ruifeng Zheng <[email protected]>
1 parent 7c0b392 commit c6e8dbe

File tree

2 files changed

+20
-4
lines changed

2 files changed

+20
-4
lines changed

python/pyspark/sql/tests/test_dataframe.py

Lines changed: 13 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,7 @@
2020
import pydoc
2121
import shutil
2222
import tempfile
23+
import warnings
2324
import unittest
2425
from typing import cast
2526
import io
@@ -1044,6 +1045,18 @@ def test_local_checkpoint_dataframe_with_storage_level(self):
10441045
df = self.spark.range(10).localCheckpoint(eager=True, storageLevel=StorageLevel.DISK_ONLY)
10451046
df.collect()
10461047

1048+
def test_socket_leak(self):
1049+
with warnings.catch_warnings(record=True) as w:
1050+
warnings.simplefilter("always", ResourceWarning)
1051+
df = self.spark.range(10)
1052+
df.collect()
1053+
1054+
df = self.spark.range(10)
1055+
for _ in df.toLocalIterator():
1056+
pass
1057+
1058+
self.assertEqual(w, [])
1059+
10471060
def test_transpose(self):
10481061
df = self.spark.createDataFrame([{"a": "x", "b": "y", "c": "z"}])
10491062

python/pyspark/util.py

Lines changed: 7 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -668,7 +668,7 @@ class PythonEvalType:
668668
SQL_ARROW_UDTF: "SQLArrowUDTFType" = 302
669669

670670

671-
def _create_local_socket(sock_info: "JavaArray") -> "io.BufferedRWPair":
671+
def _create_local_socket(sock_info: "JavaArray") -> Tuple["io.BufferedRWPair", "socket.socket"]:
672672
"""
673673
Create a local socket that can be used to load deserialized data from the JVM
674674
@@ -689,7 +689,7 @@ def _create_local_socket(sock_info: "JavaArray") -> "io.BufferedRWPair":
689689
# The RDD materialization time is unpredictable, if we set a timeout for socket reading
690690
# operation, it will very possibly fail. See SPARK-18281.
691691
sock.settimeout(None)
692-
return sockfile
692+
return sockfile, sock
693693

694694

695695
@contextmanager
@@ -710,10 +710,11 @@ def _load_from_socket(sock_info: "JavaArray", serializer: "Serializer") -> Itera
710710
usually a generator that yields deserialized data
711711
"""
712712
try:
713-
sockfile = _create_local_socket(sock_info)
713+
sockfile, sock = _create_local_socket(sock_info)
714714
yield serializer.load_stream(sockfile)
715715
finally:
716716
sockfile.close()
717+
sock.close()
717718

718719

719720
def _local_iterator_from_socket(sock_info: "JavaArray", serializer: "Serializer") -> Iterator[Any]:
@@ -725,7 +726,7 @@ def __init__(self, _sock_info: "JavaArray", _serializer: "Serializer"):
725726
auth_secret: str
726727
jsocket_auth_server: "JavaObject"
727728
port, auth_secret, self.jsocket_auth_server = _sock_info
728-
self._sockfile = _create_local_socket((port, auth_secret))
729+
self._sockfile, self._sock = _create_local_socket((port, auth_secret))
729730
self._serializer = _serializer
730731
self._read_iter: Iterator[Any] = iter([]) # Initialize as empty iterator
731732
self._read_status = 1
@@ -761,6 +762,8 @@ def __del__(self) -> None:
761762
except Exception:
762763
# Ignore any errors, socket is automatically closed when garbage-collected
763764
pass
765+
self._sockfile.close()
766+
self._sock.close()
764767

765768
return iter(PyLocalIterable(sock_info, serializer))
766769

0 commit comments

Comments
 (0)