Skip to content

Commit

Permalink
patch server_cursor_factory
Browse files Browse the repository at this point in the history
  • Loading branch information
jamestjw committed Jan 14, 2025
1 parent c59b514 commit 3334e74
Show file tree
Hide file tree
Showing 2 changed files with 160 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@

_logger = logging.getLogger(__name__)
_OTEL_CURSOR_FACTORY_KEY = "_otel_orig_cursor_factory"
_OTEL_SERVER_CURSOR_FACTORY_KEY = "_otel_orig_server_cursor_factory"


class PsycopgInstrumentor(BaseInstrumentor):
Expand Down Expand Up @@ -257,9 +258,17 @@ def instrument_connection(connection, tracer_provider=None):
setattr(
connection, _OTEL_CURSOR_FACTORY_KEY, connection.cursor_factory
)
setattr(
connection,
_OTEL_SERVER_CURSOR_FACTORY_KEY,
connection.server_cursor_factory,
)
connection.cursor_factory = _new_cursor_factory(
tracer_provider=tracer_provider
)
connection.server_cursor_factory = _new_cursor_factory(
tracer_provider=tracer_provider
)
connection._is_instrumented_by_opentelemetry = True
else:
_logger.warning(
Expand All @@ -273,6 +282,9 @@ def uninstrument_connection(connection):
connection.cursor_factory = getattr(
connection, _OTEL_CURSOR_FACTORY_KEY, None
)
connection.server_cursor_factory = getattr(
connection, _OTEL_SERVER_CURSOR_FACTORY_KEY, None
)

return connection

Expand All @@ -293,6 +305,12 @@ def wrapped_connection(
kwargs["cursor_factory"] = _new_cursor_factory(**new_factory_kwargs)
connection = connect_method(*args, **kwargs)
self.get_connection_attributes(connection)

connection.server_cursor_factory = _new_cursor_factory(
db_api=self,
base_factory=getattr(connection, "server_cursor_factory", None),
)

return connection


Expand All @@ -313,6 +331,11 @@ async def wrapped_connection(
)
connection = await connect_method(*args, **kwargs)
self.get_connection_attributes(connection)

connection.server_cursor_factory = _new_cursor_async_factory(
db_api=self,
base_factory=getattr(connection, "server_cursor_factory", None),
)
return connection


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import types
from typing import Optional
from unittest import IsolatedAsyncioTestCase, mock

import psycopg
Expand Down Expand Up @@ -83,10 +84,14 @@ class MockConnection:

def __init__(self, *args, **kwargs):
self.cursor_factory = kwargs.pop("cursor_factory", None)
self.server_cursor_factory = lambda _: MockCursor()

def cursor(self):
if self.cursor_factory:
def cursor(self, name: Optional[str] = None):
if not name and self.cursor_factory:
return self.cursor_factory(self)

if name and self.server_cursor_factory:
return self.server_cursor_factory(self)
return MockCursor()

def get_dsn_parameters(self): # pylint: disable=no-self-use
Expand All @@ -102,15 +107,18 @@ class MockAsyncConnection:

def __init__(self, *args, **kwargs):
self.cursor_factory = kwargs.pop("cursor_factory", None)
self.server_cursor_factory = lambda _: MockAsyncCursor()

@staticmethod
async def connect(*args, **kwargs):
return MockAsyncConnection(**kwargs)

def cursor(self):
if self.cursor_factory:
cur = self.cursor_factory(self)
return cur
def cursor(self, name: Optional[str] = None):
if not name and self.cursor_factory:
return self.cursor_factory(self)

if name and self.server_cursor_factory:
return self.server_cursor_factory(self)
return MockAsyncCursor()

def execute(self, query, params=None, *, prepare=None, binary=False):
Expand Down Expand Up @@ -197,6 +205,36 @@ def test_instrumentor(self):
spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)

def test_instrumentor_with_named_cursor(self):
PsycopgInstrumentor().instrument()

cnx = psycopg.connect(database="test")

cursor = cnx.cursor(name="named_cursor")

query = "SELECT * FROM test"
cursor.execute(query)

spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)
span = spans_list[0]

# Check version and name in span's instrumentation info
self.assertEqualSpanInstrumentationScope(
span, opentelemetry.instrumentation.psycopg
)

# check that no spans are generated after uninstrument
PsycopgInstrumentor().uninstrument()

cnx = psycopg.connect(database="test")
cursor = cnx.cursor(name="named_cursor")
query = "SELECT * FROM test"
cursor.execute(query)

spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)

# pylint: disable=unused-argument
def test_instrumentor_with_connection_class(self):
PsycopgInstrumentor().instrument()
Expand Down Expand Up @@ -228,6 +266,36 @@ def test_instrumentor_with_connection_class(self):
spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)

def test_instrumentor_with_connection_class_and_named_cursor(self):
PsycopgInstrumentor().instrument()

cnx = psycopg.Connection.connect(database="test")

cursor = cnx.cursor(name="named_cursor")

query = "SELECT * FROM test"
cursor.execute(query)

spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)
span = spans_list[0]

# Check version and name in span's instrumentation info
self.assertEqualSpanInstrumentationScope(
span, opentelemetry.instrumentation.psycopg
)

# check that no spans are generated after uninstrument
PsycopgInstrumentor().uninstrument()

cnx = psycopg.Connection.connect(database="test")
cursor = cnx.cursor(name="named_cursor")
query = "SELECT * FROM test"
cursor.execute(query)

spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)

def test_span_name(self):
PsycopgInstrumentor().instrument()

Expand Down Expand Up @@ -314,6 +382,23 @@ def test_instrument_connection(self):
spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)

# pylint: disable=unused-argument
def test_instrument_connection_with_named_cursor(self):
cnx = psycopg.connect(database="test")
query = "SELECT * FROM test"
cursor = cnx.cursor(name="named_cursor")
cursor.execute(query)

spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 0)

cnx = PsycopgInstrumentor().instrument_connection(cnx)
cursor = cnx.cursor(name="named_cursor")
cursor.execute(query)

spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)

# pylint: disable=unused-argument
def test_instrument_connection_with_instrument(self):
cnx = psycopg.connect(database="test")
Expand Down Expand Up @@ -368,6 +453,25 @@ def test_uninstrument_connection_with_instrument_connection(self):
spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)

def test_uninstrument_connection_with_instrument_connection_and_named_cursor(
self,
):
cnx = psycopg.connect(database="test")
PsycopgInstrumentor().instrument_connection(cnx)
query = "SELECT * FROM test"
cursor = cnx.cursor(name="named_cursor")
cursor.execute(query)

spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)

cnx = PsycopgInstrumentor().uninstrument_connection(cnx)
cursor = cnx.cursor(name="named_cursor")
cursor.execute(query)

spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)

@mock.patch("opentelemetry.instrumentation.dbapi.wrap_connect")
def test_sqlcommenter_enabled(self, event_mocked):
cnx = psycopg.connect(database="test")
Expand Down Expand Up @@ -419,6 +523,33 @@ async def test_async_connection():
spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)

async def test_wrap_async_connection_class_with_named_cursor(self):
PsycopgInstrumentor().instrument()

async def test_async_connection():
acnx = await psycopg.AsyncConnection.connect("test")
async with acnx as cnx:
async with cnx.cursor(name="named_cursor") as cursor:
await cursor.execute("SELECT * FROM test")

await test_async_connection()
spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)
span = spans_list[0]

# Check version and name in span's instrumentation info
self.assertEqualSpanInstrumentationScope(
span, opentelemetry.instrumentation.psycopg
)

# check that no spans are generated after uninstrument
PsycopgInstrumentor().uninstrument()

await test_async_connection()

spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)

# pylint: disable=unused-argument
async def test_instrumentor_with_async_connection_class(self):
PsycopgInstrumentor().instrument()
Expand Down

0 comments on commit 3334e74

Please sign in to comment.