Skip to content

Commit 813fb2d

Browse files
authored
Merge pull request #27 from TypeError/fix/performance-improvement-set-headers-v1.0.1
Fix/performance improvement set headers v1.0.1
2 parents 5a5d847 + c28882b commit 813fb2d

File tree

4 files changed

+86
-64
lines changed

4 files changed

+86
-64
lines changed

CHANGELOG.md

+6
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
99

1010
- Placeholder for upcoming changes.
1111

12+
## [1.0.1] - 2024-10-18
13+
14+
### Fixed
15+
16+
- Improved performance of `Secure.set_headers` by reducing redundant type checks. ([#26](https://github.com/TypeError/secure/issues/26))
17+
1218
## [1.0.0] - 2024-09-27
1319

1420
### Breaking Changes

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
44

55
[project]
66
name = "secure"
7-
version = "1.0.0"
7+
version = "1.0.1"
88
description = "A lightweight package that adds security headers for Python web frameworks."
99
readme = { file = "README.md", "content-type" = "text/markdown" }
1010
license = { text = "MIT" }

secure/secure.py

+29-26
Original file line numberDiff line numberDiff line change
@@ -236,22 +236,23 @@ def set_headers(self, response: ResponseProtocol) -> None:
236236
RuntimeError: If an asynchronous 'set_header' method is used in a synchronous context.
237237
AttributeError: If the response object does not support setting headers.
238238
"""
239-
for header_name, header_value in self.headers.items():
240-
if isinstance(response, SetHeaderProtocol):
241-
# If response has set_header method, use it
242-
set_header = response.set_header
243-
if inspect.iscoroutinefunction(set_header):
244-
raise RuntimeError(
245-
"Encountered asynchronous 'set_header' in synchronous context."
246-
)
239+
if isinstance(response, SetHeaderProtocol):
240+
# Use the set_header method if available
241+
set_header = response.set_header
242+
if inspect.iscoroutinefunction(set_header):
243+
raise RuntimeError(
244+
"Encountered asynchronous 'set_header' in synchronous context."
245+
)
246+
for header_name, header_value in self.headers.items():
247247
set_header(header_name, header_value)
248-
elif isinstance(response, HeadersProtocol): # type: ignore
249-
# If response has headers dictionary, use it
248+
elif isinstance(response, HeadersProtocol): # type: ignore
249+
# Use the headers dictionary if available
250+
for header_name, header_value in self.headers.items():
250251
response.headers[header_name] = header_value
251-
else:
252-
raise AttributeError(
253-
f"Response object of type '{type(response).__name__}' does not support setting headers."
254-
)
252+
else:
253+
raise AttributeError(
254+
f"Response object of type '{type(response).__name__}' does not support setting headers."
255+
)
255256

256257
async def set_headers_async(self, response: ResponseProtocol) -> None:
257258
"""
@@ -266,18 +267,20 @@ async def set_headers_async(self, response: ResponseProtocol) -> None:
266267
Raises:
267268
AttributeError: If the response object does not support setting headers.
268269
"""
269-
for header_name, header_value in self.headers.items():
270-
if isinstance(response, SetHeaderProtocol):
271-
# If response has set_header method, use it
272-
set_header = response.set_header
273-
if inspect.iscoroutinefunction(set_header):
270+
if isinstance(response, SetHeaderProtocol):
271+
# Use the set_header method if available
272+
set_header = response.set_header
273+
if inspect.iscoroutinefunction(set_header):
274+
for header_name, header_value in self.headers.items():
274275
await set_header(header_name, header_value)
275-
else:
276+
else:
277+
for header_name, header_value in self.headers.items():
276278
set_header(header_name, header_value)
277-
elif isinstance(response, HeadersProtocol): # type: ignore
278-
# If response has headers dictionary, use it
279+
elif isinstance(response, HeadersProtocol): # type: ignore
280+
# Use the headers dictionary if available
281+
for header_name, header_value in self.headers.items():
279282
response.headers[header_name] = header_value
280-
else:
281-
raise AttributeError(
282-
f"Response object of type '{type(response).__name__}' does not support setting headers."
283-
)
283+
else:
284+
raise AttributeError(
285+
f"Response object of type '{type(response).__name__}' does not support setting headers."
286+
)

tests/secure/test_secure.py

+50-37
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
import unittest
23

34
from secure import (
@@ -44,6 +45,20 @@ class MockResponseNoHeaders:
4445

4546

4647
class TestSecure(unittest.TestCase):
48+
def setUp(self):
49+
# Initialize Secure with some test headers
50+
self.secure = Secure(
51+
custom=[
52+
CustomHeader("X-Test-Header-1", "Value1"),
53+
CustomHeader("X-Test-Header-2", "Value2"),
54+
]
55+
)
56+
# Precompute headers dictionary
57+
self.secure.headers = {
58+
header.header_name: header.header_value
59+
for header in self.secure.headers_list
60+
}
61+
4762
def test_with_default_headers(self):
4863
"""Test that default headers are correctly applied."""
4964
secure_headers = Secure.with_default_headers()
@@ -210,8 +225,6 @@ def test_async_set_headers(self):
210225
async def mock_set_headers():
211226
await secure_headers.set_headers_async(response)
212227

213-
import asyncio
214-
215228
asyncio.run(mock_set_headers())
216229

217230
# Verify that headers are set asynchronously
@@ -235,43 +248,43 @@ async def mock_set_headers():
235248

236249
def test_set_headers_with_set_header_method(self):
237250
"""Test setting headers on a response object with set_header method."""
238-
secure_headers = Secure.with_default_headers()
239251
response = MockResponseWithSetHeader()
240-
241-
# Apply the headers to the response object
242-
secure_headers.set_headers(response)
252+
self.secure.set_headers(response)
243253

244254
# Verify that headers are set using set_header method
245-
self.assertIn("Strict-Transport-Security", response.header_storage)
246-
self.assertEqual(
247-
response.header_storage["Strict-Transport-Security"],
248-
"max-age=31536000",
249-
)
255+
self.assertEqual(response.header_storage, self.secure.headers)
256+
# Ensure set_header was called correct number of times
257+
self.assertEqual(len(response.header_storage), len(self.secure.headers))
250258

251-
self.assertIn("X-Content-Type-Options", response.header_storage)
252-
self.assertEqual(response.header_storage["X-Content-Type-Options"], "nosniff")
259+
def test_set_headers_with_headers_dict(self):
260+
"""Test set_headers with a response object that has a headers dictionary."""
261+
response = MockResponse()
262+
self.secure.set_headers(response)
253263

254-
def test_async_set_headers_with_async_set_header_method(self):
255-
"""Test async setting headers on a response object with async set_header method."""
256-
secure_headers = Secure.with_default_headers()
257-
response = MockResponseAsyncSetHeader()
264+
# Verify that headers are set
265+
self.assertEqual(response.headers, self.secure.headers)
258266

259-
async def mock_set_headers():
260-
await secure_headers.set_headers_async(response)
267+
def test_set_headers_async_with_async_set_header(self):
268+
"""Test set_headers_async with a response object that has an asynchronous set_header method."""
269+
response = MockResponseAsyncSetHeader()
261270

262-
import asyncio
271+
async def test_async():
272+
await self.secure.set_headers_async(response)
263273

264-
asyncio.run(mock_set_headers())
274+
asyncio.run(test_async())
265275

266276
# Verify that headers are set using async set_header method
267-
self.assertIn("Strict-Transport-Security", response.header_storage)
268-
self.assertEqual(
269-
response.header_storage["Strict-Transport-Security"],
270-
"max-age=31536000",
271-
)
277+
self.assertEqual(response.header_storage, self.secure.headers)
278+
# Ensure set_header was called correct number of times
279+
self.assertEqual(len(response.header_storage), len(self.secure.headers))
280+
281+
def test_set_headers_async_with_headers_dict(self):
282+
"""Test set_headers_async with a response object that has a headers dictionary."""
283+
response = MockResponse()
284+
asyncio.run(self.secure.set_headers_async(response))
272285

273-
self.assertIn("X-Content-Type-Options", response.header_storage)
274-
self.assertEqual(response.header_storage["X-Content-Type-Options"], "nosniff")
286+
# Verify that headers are set
287+
self.assertEqual(response.headers, self.secure.headers)
275288

276289
def test_set_headers_missing_interface(self):
277290
"""Test that an error is raised when response object lacks required methods."""
@@ -286,6 +299,12 @@ def test_set_headers_missing_interface(self):
286299
str(context.exception),
287300
)
288301

302+
def test_set_headers_with_async_set_header_in_sync_context(self):
303+
"""Test set_headers raises RuntimeError when encountering async set_header in sync context."""
304+
response = MockResponseAsyncSetHeader()
305+
with self.assertRaises(RuntimeError):
306+
self.secure.set_headers(response)
307+
289308
def test_set_headers_overwrites_existing_headers(self):
290309
"""Test that existing headers are overwritten by Secure."""
291310
secure_headers = Secure.with_default_headers()
@@ -347,10 +366,10 @@ def test_invalid_preset(self):
347366

348367
def test_empty_secure_instance(self):
349368
"""Test that an empty Secure instance does not set any headers."""
350-
secure_headers = Secure()
369+
self.secure = Secure()
351370
response = MockResponse()
352371

353-
secure_headers.set_headers(response)
372+
self.secure.set_headers(response)
354373
self.assertEqual(len(response.headers), 0)
355374

356375
def test_multiple_custom_headers(self):
@@ -430,16 +449,10 @@ def test_set_headers_async_with_sync_set_header(self):
430449
async def mock_set_headers():
431450
await secure_headers.set_headers_async(response)
432451

433-
import asyncio
434-
435452
asyncio.run(mock_set_headers())
436453

437454
# Verify that headers are set using set_header method
438-
self.assertIn("Strict-Transport-Security", response.header_storage)
439-
self.assertEqual(
440-
response.header_storage["Strict-Transport-Security"],
441-
"max-age=31536000",
442-
)
455+
self.assertEqual(response.header_storage, secure_headers.headers)
443456

444457
def test_set_headers_with_no_headers_or_set_header(self):
445458
"""Test that an error is raised when response lacks both headers and set_header."""

0 commit comments

Comments
 (0)