Skip to content

Commit b5eebcc

Browse files
Merge pull request #106 from microsoft/psl-unit-test-cases
Client Advisor | Backend test cases added
2 parents f612b30 + b4d119b commit b5eebcc

File tree

8 files changed

+1850
-1
lines changed

8 files changed

+1850
-1
lines changed

.github/workflows/test_client_advisor.yml

+17
Original file line numberDiff line numberDiff line change
@@ -47,3 +47,20 @@ jobs:
4747
path: |
4848
ClientAdvisor/App/frontend/coverage/
4949
ClientAdvisor/App/frontend/coverage/lcov-report/
50+
ClientAdvisor/App/htmlcov/
51+
- name: Install Backend Dependencies
52+
run: |
53+
cd ClientAdvisor/App
54+
python -m pip install -r requirements.txt
55+
python -m pip install coverage pytest-cov
56+
- name: Run Backend Tests with Coverage
57+
run: |
58+
cd ClientAdvisor/App
59+
python -m pytest -vv --cov=. --cov-report=xml --cov-report=html --cov-report=term-missing --cov-fail-under=80 --junitxml=coverage-junit.xml
60+
- uses: actions/upload-artifact@v4
61+
with:
62+
name: client-advisor-coverage
63+
path: |
64+
ClientAdvisor/App/coverage.xml
65+
ClientAdvisor/App/coverage-junit.xml
66+
ClientAdvisor/App/htmlcov/

ClientAdvisor/App/requirements-dev.txt

+2
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,6 @@ httpx==0.27.0
1515
flake8==7.1.1
1616
black==24.8.0
1717
autoflake==2.3.1
18+
isort==5.13.2pytest-asyncio==0.24.0
19+
pytest-cov==5.0.0
1820
isort==5.13.2

ClientAdvisor/App/requirements.txt

+3-1
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,6 @@ httpx==0.27.0
1515
flake8==7.1.1
1616
black==24.8.0
1717
autoflake==2.3.1
18-
isort==5.13.2
18+
isort==5.13.2
19+
pytest-asyncio==0.24.0
20+
pytest-cov==5.0.0
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
import base64
2+
import json
3+
from unittest.mock import patch
4+
5+
from backend.auth.auth_utils import (get_authenticated_user_details,
6+
get_tenantid)
7+
8+
9+
def test_get_authenticated_user_details_no_principal_id():
10+
request_headers = {}
11+
sample_user_data = {
12+
"X-Ms-Client-Principal-Id": "default-id",
13+
"X-Ms-Client-Principal-Name": "default-name",
14+
"X-Ms-Client-Principal-Idp": "default-idp",
15+
"X-Ms-Token-Aad-Id-Token": "default-token",
16+
"X-Ms-Client-Principal": "default-b64",
17+
}
18+
with patch("backend.auth.sample_user.sample_user", sample_user_data):
19+
user_details = get_authenticated_user_details(request_headers)
20+
assert user_details["user_principal_id"] == "default-id"
21+
assert user_details["user_name"] == "default-name"
22+
assert user_details["auth_provider"] == "default-idp"
23+
assert user_details["auth_token"] == "default-token"
24+
assert user_details["client_principal_b64"] == "default-b64"
25+
26+
27+
def test_get_authenticated_user_details_with_principal_id():
28+
request_headers = {
29+
"X-Ms-Client-Principal-Id": "test-id",
30+
"X-Ms-Client-Principal-Name": "test-name",
31+
"X-Ms-Client-Principal-Idp": "test-idp",
32+
"X-Ms-Token-Aad-Id-Token": "test-token",
33+
"X-Ms-Client-Principal": "test-b64",
34+
}
35+
user_details = get_authenticated_user_details(request_headers)
36+
assert user_details["user_principal_id"] == "test-id"
37+
assert user_details["user_name"] == "test-name"
38+
assert user_details["auth_provider"] == "test-idp"
39+
assert user_details["auth_token"] == "test-token"
40+
assert user_details["client_principal_b64"] == "test-b64"
41+
42+
43+
def test_get_tenantid_valid_b64():
44+
user_info = {"tid": "test-tenant-id"}
45+
client_principal_b64 = base64.b64encode(
46+
json.dumps(user_info).encode("utf-8")
47+
).decode("utf-8")
48+
tenant_id = get_tenantid(client_principal_b64)
49+
assert tenant_id == "test-tenant-id"
50+
51+
52+
def test_get_tenantid_invalid_b64():
53+
client_principal_b64 = "invalid-b64"
54+
with patch("backend.auth.auth_utils.logging") as mock_logging:
55+
tenant_id = get_tenantid(client_principal_b64)
56+
assert tenant_id == ""
57+
mock_logging.exception.assert_called_once()
58+
59+
60+
def test_get_tenantid_no_tid():
61+
user_info = {"some_other_key": "value"}
62+
client_principal_b64 = base64.b64encode(
63+
json.dumps(user_info).encode("utf-8")
64+
).decode("utf-8")
65+
tenant_id = get_tenantid(client_principal_b64)
66+
assert tenant_id is None
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
from unittest.mock import AsyncMock, MagicMock, patch
2+
3+
import pytest
4+
from azure.cosmos import exceptions
5+
6+
from backend.history.cosmosdbservice import CosmosConversationClient
7+
8+
9+
# Helper function to create an async iterable
10+
class AsyncIterator:
11+
def __init__(self, items):
12+
self.items = items
13+
self.index = 0
14+
15+
def __aiter__(self):
16+
return self
17+
18+
async def __anext__(self):
19+
if self.index < len(self.items):
20+
item = self.items[self.index]
21+
self.index += 1
22+
return item
23+
else:
24+
raise StopAsyncIteration
25+
26+
27+
@pytest.fixture
28+
def cosmos_client():
29+
return CosmosConversationClient(
30+
cosmosdb_endpoint="https://fake.endpoint",
31+
credential="fake_credential",
32+
database_name="test_db",
33+
container_name="test_container",
34+
)
35+
36+
37+
@pytest.mark.asyncio
38+
async def test_init_invalid_credentials():
39+
with patch(
40+
"azure.cosmos.aio.CosmosClient.__init__",
41+
side_effect=exceptions.CosmosHttpResponseError(
42+
status_code=401, message="Unauthorized"
43+
),
44+
):
45+
with pytest.raises(ValueError, match="Invalid credentials"):
46+
CosmosConversationClient(
47+
cosmosdb_endpoint="https://fake.endpoint",
48+
credential="fake_credential",
49+
database_name="test_db",
50+
container_name="test_container",
51+
)
52+
53+
54+
@pytest.mark.asyncio
55+
async def test_init_invalid_endpoint():
56+
with patch(
57+
"azure.cosmos.aio.CosmosClient.__init__",
58+
side_effect=exceptions.CosmosHttpResponseError(
59+
status_code=404, message="Not Found"
60+
),
61+
):
62+
with pytest.raises(ValueError, match="Invalid CosmosDB endpoint"):
63+
CosmosConversationClient(
64+
cosmosdb_endpoint="https://fake.endpoint",
65+
credential="fake_credential",
66+
database_name="test_db",
67+
container_name="test_container",
68+
)
69+
70+
71+
@pytest.mark.asyncio
72+
async def test_ensure_success(cosmos_client):
73+
cosmos_client.database_client.read = AsyncMock()
74+
cosmos_client.container_client.read = AsyncMock()
75+
success, message = await cosmos_client.ensure()
76+
assert success
77+
assert message == "CosmosDB client initialized successfully"
78+
79+
80+
@pytest.mark.asyncio
81+
async def test_ensure_failure(cosmos_client):
82+
cosmos_client.database_client.read = AsyncMock(side_effect=Exception)
83+
success, message = await cosmos_client.ensure()
84+
assert not success
85+
assert "CosmosDB database" in message
86+
87+
88+
@pytest.mark.asyncio
89+
async def test_create_conversation(cosmos_client):
90+
cosmos_client.container_client.upsert_item = AsyncMock(return_value={"id": "123"})
91+
response = await cosmos_client.create_conversation("user_1", "Test Conversation")
92+
assert response["id"] == "123"
93+
94+
95+
@pytest.mark.asyncio
96+
async def test_create_conversation_failure(cosmos_client):
97+
cosmos_client.container_client.upsert_item = AsyncMock(return_value=None)
98+
response = await cosmos_client.create_conversation("user_1", "Test Conversation")
99+
assert not response
100+
101+
102+
@pytest.mark.asyncio
103+
async def test_upsert_conversation(cosmos_client):
104+
cosmos_client.container_client.upsert_item = AsyncMock(return_value={"id": "123"})
105+
response = await cosmos_client.upsert_conversation({"id": "123"})
106+
assert response["id"] == "123"
107+
108+
109+
@pytest.mark.asyncio
110+
async def test_delete_conversation(cosmos_client):
111+
cosmos_client.container_client.read_item = AsyncMock(return_value={"id": "123"})
112+
cosmos_client.container_client.delete_item = AsyncMock(return_value=True)
113+
response = await cosmos_client.delete_conversation("user_1", "123")
114+
assert response
115+
116+
117+
@pytest.mark.asyncio
118+
async def test_delete_conversation_not_found(cosmos_client):
119+
cosmos_client.container_client.read_item = AsyncMock(return_value=None)
120+
response = await cosmos_client.delete_conversation("user_1", "123")
121+
assert response
122+
123+
124+
@pytest.mark.asyncio
125+
async def test_delete_messages(cosmos_client):
126+
cosmos_client.get_messages = AsyncMock(
127+
return_value=[{"id": "msg_1"}, {"id": "msg_2"}]
128+
)
129+
cosmos_client.container_client.delete_item = AsyncMock(return_value=True)
130+
response = await cosmos_client.delete_messages("conv_1", "user_1")
131+
assert len(response) == 2
132+
133+
134+
@pytest.mark.asyncio
135+
async def test_get_conversations(cosmos_client):
136+
items = [{"id": "conv_1"}, {"id": "conv_2"}]
137+
cosmos_client.container_client.query_items = MagicMock(
138+
return_value=AsyncIterator(items)
139+
)
140+
response = await cosmos_client.get_conversations("user_1", 10)
141+
assert len(response) == 2
142+
assert response[0]["id"] == "conv_1"
143+
assert response[1]["id"] == "conv_2"
144+
145+
146+
@pytest.mark.asyncio
147+
async def test_get_conversation(cosmos_client):
148+
items = [{"id": "conv_1"}]
149+
cosmos_client.container_client.query_items = MagicMock(
150+
return_value=AsyncIterator(items)
151+
)
152+
response = await cosmos_client.get_conversation("user_1", "conv_1")
153+
assert response["id"] == "conv_1"
154+
155+
156+
@pytest.mark.asyncio
157+
async def test_create_message(cosmos_client):
158+
cosmos_client.container_client.upsert_item = AsyncMock(return_value={"id": "msg_1"})
159+
cosmos_client.get_conversation = AsyncMock(return_value={"id": "conv_1"})
160+
cosmos_client.upsert_conversation = AsyncMock()
161+
response = await cosmos_client.create_message(
162+
"msg_1", "conv_1", "user_1", {"role": "user", "content": "Hello"}
163+
)
164+
assert response["id"] == "msg_1"
165+
166+
167+
@pytest.mark.asyncio
168+
async def test_update_message_feedback(cosmos_client):
169+
cosmos_client.container_client.read_item = AsyncMock(return_value={"id": "msg_1"})
170+
cosmos_client.container_client.upsert_item = AsyncMock(return_value={"id": "msg_1"})
171+
response = await cosmos_client.update_message_feedback(
172+
"user_1", "msg_1", "positive"
173+
)
174+
assert response["id"] == "msg_1"
175+
176+
177+
@pytest.mark.asyncio
178+
async def test_get_messages(cosmos_client):
179+
items = [{"id": "msg_1"}, {"id": "msg_2"}]
180+
cosmos_client.container_client.query_items = MagicMock(
181+
return_value=AsyncIterator(items)
182+
)
183+
response = await cosmos_client.get_messages("user_1", "conv_1")
184+
assert len(response) == 2

0 commit comments

Comments
 (0)