diff --git a/gapic/cli/generate.py b/gapic/cli/generate.py index e8eee1f034..212f7de4f7 100644 --- a/gapic/cli/generate.py +++ b/gapic/cli/generate.py @@ -23,6 +23,7 @@ from gapic import generator from gapic.schema import api from gapic.utils import Options +from gapic.utils.cache import generation_cache_context @click.command() @@ -56,15 +57,23 @@ def generate(request: typing.BinaryIO, output: typing.BinaryIO) -> None: [p.package for p in req.proto_file if p.name in req.file_to_generate] ).rstrip(".") - # Build the API model object. - # This object is a frozen representation of the whole API, and is sent - # to each template in the rendering step. - api_schema = api.API.build(req.proto_file, opts=opts, package=package) + # Create the generation cache context. + # This provides the shared storage for the @cached_proto_context decorator. + # 1. Performance: Memoizes `with_context` calls, speeding up generation significantly. + # 2. Safety: The decorator uses this storage to "pin" Proto objects in memory. + # This prevents Python's Garbage Collector from deleting objects created during + # `API.build` while `Generator.get_response` is still using their IDs. + # (See `gapic.utils.cache.cached_proto_context` for the specific pinning logic). + with generation_cache_context(): + # Build the API model object. + # This object is a frozen representation of the whole API, and is sent + # to each template in the rendering step. + api_schema = api.API.build(req.proto_file, opts=opts, package=package) - # Translate into a protobuf CodeGeneratorResponse; this reads the - # individual templates and renders them. - # If there are issues, error out appropriately. - res = generator.Generator(opts).get_response(api_schema, opts) + # Translate into a protobuf CodeGeneratorResponse; this reads the + # individual templates and renders them. + # If there are issues, error out appropriately. + res = generator.Generator(opts).get_response(api_schema, opts) # Output the serialized response. 
output.write(res.SerializeToString()) diff --git a/gapic/schema/metadata.py b/gapic/schema/metadata.py index 7df8d0291f..8defa70dc2 100644 --- a/gapic/schema/metadata.py +++ b/gapic/schema/metadata.py @@ -35,6 +35,7 @@ from gapic.schema import imp from gapic.schema import naming from gapic.utils import cached_property +from gapic.utils import cached_proto_context from gapic.utils import RESERVED_NAMES # This class is a minor hack to optimize Address's __eq__ method. @@ -359,6 +360,7 @@ def resolve(self, selector: str) -> str: return f'{".".join(self.package)}.{selector}' return selector + @cached_proto_context def with_context(self, *, collisions: Set[str]) -> "Address": """Return a derivative of this address with the provided context. @@ -398,6 +400,7 @@ def doc(self): return "\n\n".join(self.documentation.leading_detached_comments) return "" + @cached_proto_context def with_context(self, *, collisions: Set[str]) -> "Metadata": """Return a derivative of this metadata with the provided context. diff --git a/gapic/schema/wrappers.py b/gapic/schema/wrappers.py index 17a7832756..3c3c3d5849 100644 --- a/gapic/schema/wrappers.py +++ b/gapic/schema/wrappers.py @@ -67,6 +67,7 @@ from gapic import utils from gapic.schema import metadata +from gapic.utils import cached_proto_context from gapic.utils import uri_sample from gapic.utils import make_private @@ -410,6 +411,7 @@ def type(self) -> Union["MessageType", "EnumType", "PrimitiveType"]: "This code should not be reachable; please file a bug." ) + @cached_proto_context def with_context( self, *, @@ -805,6 +807,7 @@ def get_field( # message. 
return cursor.message.get_field(*field_path[1:], collisions=collisions) + @cached_proto_context def with_context( self, *, @@ -937,6 +940,7 @@ def ident(self) -> metadata.Address: """Return the identifier data to be used in templates.""" return self.meta.address + @cached_proto_context def with_context(self, *, collisions: Set[str]) -> "EnumType": """Return a derivative of this enum with the provided context. @@ -1058,6 +1062,7 @@ class ExtendedOperationInfo: request_type: MessageType operation_type: MessageType + @cached_proto_context def with_context( self, *, @@ -1127,6 +1132,7 @@ class OperationInfo: response_type: MessageType metadata_type: MessageType + @cached_proto_context def with_context( self, *, @@ -1937,6 +1943,7 @@ def void(self) -> bool: """Return True if this method has no return value, False otherwise.""" return self.output.ident.proto == "google.protobuf.Empty" + @cached_proto_context def with_context( self, *, @@ -2357,6 +2364,7 @@ def operation_polling_method(self) -> Optional[Method]: def is_internal(self) -> bool: return any(m.is_internal for m in self.methods.values()) + @cached_proto_context def with_context( self, *, diff --git a/gapic/utils/__init__.py b/gapic/utils/__init__.py index 8b48801730..23c5739156 100644 --- a/gapic/utils/__init__.py +++ b/gapic/utils/__init__.py @@ -13,6 +13,7 @@ # limitations under the License. from gapic.utils.cache import cached_property +from gapic.utils.cache import cached_proto_context from gapic.utils.case import to_snake_case from gapic.utils.case import to_camel_case from gapic.utils.checks import is_msg_field_pb @@ -34,6 +35,7 @@ __all__ = ( "cached_property", + "cached_proto_context", "convert_uri_fieldnames", "doc", "empty", diff --git a/gapic/utils/cache.py b/gapic/utils/cache.py index f9c4d703f5..5f9b0f658d 100644 --- a/gapic/utils/cache.py +++ b/gapic/utils/cache.py @@ -13,6 +13,8 @@ # limitations under the License. 
import functools +import contextlib +import threading def cached_property(fx): @@ -43,3 +45,91 @@ def inner(self): return self._cached_values[fx.__name__] return property(inner) + + +# Thread-local storage for the simple cache dictionary. +# This ensures that parallel generation tasks (if any) do not corrupt each other's cache. +_thread_local = threading.local() + + +@contextlib.contextmanager +def generation_cache_context(): + """Context manager to explicitly manage the lifecycle of the generation cache. + + This manager initializes a fresh dictionary in thread-local storage when entering + the context and strictly deletes it when exiting. + + **Memory Management:** + The cache stores strong references to Proto objects to "pin" them in memory + (see `cached_proto_context`). It is critical that this context manager deletes + the dictionary in the `finally` block. Deleting the dictionary breaks the + reference chain, allowing Python's Garbage Collector to finally free all the + large Proto objects that were pinned during generation. + """ + # Initialize the cache as a standard dictionary. + _thread_local.cache = {} + try: + yield + finally: + # Delete the dictionary to free all memory and pinned objects. + # This is essential to prevent memory leaks in long-running processes. + del _thread_local.cache + + +def cached_proto_context(func): + """Decorator to memoize `with_context` calls based on object identity and collisions. + + This mechanism provides a significant performance boost by preventing + redundant recalculations of naming collisions during template rendering. + + Since the Proto wrapper objects are unhashable (mutable), we use `id(self)` as + the primary cache key. Normally, this is dangerous: if the object is garbage + collected, Python might reuse its memory address for a *new* object, leading to + a cache collision (the "Zombie ID" bug). + + To prevent this, this decorator stores the value as a tuple: `(result, self)`. 
+ By keeping a reference to `self` in the cache value, we "pin" the object in + memory. This forces the Garbage Collector to keep the object alive, guaranteeing + that `id(self)` remains unique for the entire lifespan of the `generation_cache_context`. + + Args: + func (Callable): The function to decorate (usually `with_context`). + + Returns: + Callable: The wrapped function with caching and pinning logic. + """ + + @functools.wraps(func) + def wrapper(self, *, collisions, **kwargs): + + # 1. Check for active cache (returns None if context is not active) + context_cache = getattr(_thread_local, "cache", None) + + # If we are not inside a generation_cache_context (e.g. unit tests), + # bypass the cache entirely. + if context_cache is None: + return func(self, collisions=collisions, **kwargs) + + # 2. Create the cache key + # We use frozenset for collisions to make it hashable. + # We use id(self) because 'self' is not hashable. + collisions_key = frozenset(collisions) if collisions else None + key = (id(self), collisions_key) + + # 3. Check Cache + if key in context_cache: + # The cache stores (result, pinned_object). We return just the result. + return context_cache[key][0] + + # 4. Execute the actual function + # Any recursive with_context calls made inside func re-read the same + # thread-local cache themselves, so nothing needs to be passed down. + result = func(self, collisions=collisions, **kwargs) + + # 5. Update Cache & Pin Object + # We store (result, self). The reference to 'self' prevents garbage collection, + # ensuring that 'id(self)' cannot be reused for a new object while this + # cache entry exists. 
+ context_cache[key] = (result, self) + return result + + return wrapper diff --git a/tests/unit/utils/test_cache.py b/tests/unit/utils/test_cache.py index eee6865e6d..ec7f82438b 100644 --- a/tests/unit/utils/test_cache.py +++ b/tests/unit/utils/test_cache.py @@ -31,3 +31,43 @@ def bar(self): assert foo.call_count == 1 assert foo.bar == 42 assert foo.call_count == 1 + + +def test_cached_proto_context(): + class Foo: + def __init__(self): + self.call_count = 0 + + # We define a signature that matches the real Proto.with_context + # to ensure arguments are propagated correctly. + @cache.cached_proto_context + def with_context(self, collisions, *, skip_fields=False, visited_messages=None): + self.call_count += 1 + return f"val-{self.call_count}" + + foo = Foo() + + # 1. Test Bypass (No Context) + # The cache is not active, so every call increments the counter. + assert foo.with_context(collisions={"a"}) == "val-1" + assert foo.with_context(collisions={"a"}) == "val-2" + + # 2. Test Context Activation + with cache.generation_cache_context(): + # Reset counter to make tracking easier + foo.call_count = 0 + + # A. Basic Cache Hit + assert foo.with_context(collisions={"a"}) == "val-1", "a" + assert foo.with_context(collisions={"a"}) == "val-1" # Hit + assert foo.call_count == 1 + + # B. Collision Difference + # Changing collisions creates a new key + assert foo.with_context(collisions={"b"}) == "val-2" + assert foo.call_count == 2 + + # 3. Context Cleared + # Everything should be forgotten now. + assert getattr(cache._thread_local, "cache", None) is None + assert foo.with_context(collisions={"a"}) == "val-3"