diff --git a/.github/workflows/ci-checks.yaml b/.github/workflows/ci-checks.yaml
index e8c59ff..3e244d5 100644
--- a/.github/workflows/ci-checks.yaml
+++ b/.github/workflows/ci-checks.yaml
@@ -36,3 +36,29 @@ jobs:
- name: Run pre-commit
run: uv run pre-commit run --all-files
+
+ job-image-processing-unit-tests:
+ name: Image Processing Unit Tests
+ runs-on: ubuntu-latest
+
+ steps:
+ - name: Checkout code
+ uses: actions/checkout@v3
+
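+      # MIN_PYTHON_VERSION is assumed to be defined in the workflow-level env block shared with the existing jobs.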
+ - name: Set up Python
+ uses: actions/setup-python@v3
+ with:
+ python-version: ${{ env.MIN_PYTHON_VERSION }}
+
+ - name: Install uv
+ uses: astral-sh/setup-uv@v4
+ with:
+ enable-cache: true
+
+ - name: Install the project
+ run: uv sync
+ working-directory: image_processing
+
+ - name: Run PyTest
+ run: uv run pytest --cov=. --cov-config=.coveragerc
+ working-directory: image_processing
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 9610159..563d806 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -18,6 +18,7 @@ repos:
# Python checks
- id: name-tests-test
+ args: [--pytest-test-first]
# JSON files
- id: pretty-format-json
diff --git a/deploy_ai_search_indexes/src/deploy_ai_search_indexes/ai_search.py b/deploy_ai_search_indexes/src/deploy_ai_search_indexes/ai_search.py
index 44be7f0..8a872f8 100644
--- a/deploy_ai_search_indexes/src/deploy_ai_search_indexes/ai_search.py
+++ b/deploy_ai_search_indexes/src/deploy_ai_search_indexes/ai_search.py
@@ -219,7 +219,11 @@ def get_mark_up_cleaner_skill(self, chunk_by_page: False) -> WebApiSkill:
mark_up_cleaner_context = "/document/page_wise_layout/*"
inputs = [
InputFieldMappingEntry(
- name="chunk", source="/document/page_wise_layout/*/merged_content"
+ name="mark_up", source="/document/page_wise_layout/*/merged_content"
+ ),
+ InputFieldMappingEntry(
+ name="page_number",
+ source="/document/page_wise_layout/*/page_number",
),
InputFieldMappingEntry(
name="figures",
@@ -230,7 +234,10 @@ def get_mark_up_cleaner_skill(self, chunk_by_page: False) -> WebApiSkill:
mark_up_cleaner_context = "/document/chunk_mark_ups/*"
inputs = [
InputFieldMappingEntry(
- name="chunk", source="/document/chunk_mark_ups/*"
+ name="mark_up", source="/document/chunk_mark_ups/*/mark_up"
+ ),
+ InputFieldMappingEntry(
+ name="page_number", source="/document/chunk_mark_ups/*/page_number"
),
InputFieldMappingEntry(
name="figures", source="/document/layout/figures/*/updated_figure"
@@ -238,12 +245,15 @@ def get_mark_up_cleaner_skill(self, chunk_by_page: False) -> WebApiSkill:
]
mark_up_cleaner_skill_outputs = [
- OutputFieldMappingEntry(name="chunk_cleaned", target_name="chunk_cleaned"),
OutputFieldMappingEntry(
- name="chunk_sections", target_name="chunk_sections"
+ name="cleaned_text", target_name="final_cleaned_text"
+ ),
+ OutputFieldMappingEntry(name="sections", target_name="final_sections"),
+ OutputFieldMappingEntry(name="mark_up", target_name="final_mark_up"),
+ OutputFieldMappingEntry(name="figures", target_name="final_chunk_figures"),
+ OutputFieldMappingEntry(
+ name="page_number", target_name="final_page_number"
),
- OutputFieldMappingEntry(name="chunk_mark_up", target_name="chunk_mark_up"),
- OutputFieldMappingEntry(name="chunk_figures", target_name="chunk_figures"),
]
mark_up_cleaner_skill = WebApiSkill(
@@ -302,7 +312,11 @@ def get_semantic_chunker_skill(
semantic_text_chunker_skill_inputs = [
InputFieldMappingEntry(
name="content", source="/document/layout_merged_content"
- )
+ ),
+ InputFieldMappingEntry(
+ name="per_page_starting_sentences",
+ source="/document/per_page_starting_sentences",
+ ),
]
semantic_text_chunker_skill_outputs = [
@@ -368,7 +382,13 @@ def get_layout_analysis_skill(
)
]
else:
- output = [OutputFieldMappingEntry(name="layout", target_name="layout")]
+ output = [
+ OutputFieldMappingEntry(name="layout", target_name="layout"),
+ OutputFieldMappingEntry(
+ name="per_page_starting_sentences",
+ target_name="per_page_starting_sentences",
+ ),
+ ]
layout_analysis_skill = WebApiSkill(
name="Layout Analysis Skill",
diff --git a/deploy_ai_search_indexes/src/deploy_ai_search_indexes/image_processing.py b/deploy_ai_search_indexes/src/deploy_ai_search_indexes/image_processing.py
index eb11fba..b1f875b 100644
--- a/deploy_ai_search_indexes/src/deploy_ai_search_indexes/image_processing.py
+++ b/deploy_ai_search_indexes/src/deploy_ai_search_indexes/image_processing.py
@@ -81,6 +81,13 @@ def get_index_fields(self) -> list[SearchableField]:
type=SearchFieldDataType.String,
collection=True,
),
+ SimpleField(
+ name="PageNumber",
+ type=SearchFieldDataType.Int64,
+ sortable=True,
+ filterable=True,
+ facetable=True,
+ ),
SearchField(
name="ChunkEmbedding",
type=SearchFieldDataType.Collection(SearchFieldDataType.Single),
@@ -137,19 +144,6 @@ def get_index_fields(self) -> list[SearchableField]:
),
]
- if self.enable_page_by_chunking:
- fields.extend(
- [
- SimpleField(
- name="PageNumber",
- type=SearchFieldDataType.Int64,
- sortable=True,
- filterable=True,
- facetable=True,
- )
- ]
- )
-
return fields
def get_semantic_search(self) -> SemanticSearch:
@@ -194,11 +188,12 @@ def get_skills(self) -> list:
if self.enable_page_by_chunking:
embedding_skill = self.get_vector_skill(
"/document/page_wise_layout/*",
- "/document/page_wise_layout/*/chunk_cleaned",
+ "/document/page_wise_layout/*/final_cleaned_text",
)
else:
embedding_skill = self.get_vector_skill(
- "/document/chunk_mark_ups/*", "/document/chunk_mark_ups/*/chunk_cleaned"
+ "/document/chunk_mark_ups/*",
+ "/document/chunk_mark_ups/*/final_cleaned_text",
)
if self.enable_page_by_chunking:
@@ -229,7 +224,7 @@ def get_index_projections(self) -> SearchIndexerIndexProjection:
source_context = "/document/page_wise_layout/*"
mappings = [
InputFieldMappingEntry(
- name="Chunk", source="/document/page_wise_layout/*/chunk_mark_up"
+ name="Chunk", source="/document/page_wise_layout/*/final_mark_up"
),
InputFieldMappingEntry(
name="ChunkEmbedding",
@@ -239,24 +234,25 @@ def get_index_projections(self) -> SearchIndexerIndexProjection:
InputFieldMappingEntry(name="SourceUri", source="/document/SourceUri"),
InputFieldMappingEntry(
name="Sections",
- source="/document/page_wise_layout/*/chunk_sections",
+ source="/document/page_wise_layout/*/final_sections",
),
InputFieldMappingEntry(
name="ChunkFigures",
- source="/document/page_wise_layout/*/chunk_figures/*",
+ source="/document/page_wise_layout/*/final_chunk_figures/*",
),
InputFieldMappingEntry(
name="DateLastModified", source="/document/DateLastModified"
),
InputFieldMappingEntry(
- name="PageNumber", source="/document/page_wise_layout/*/page_number"
+ name="PageNumber",
+ source="/document/page_wise_layout/*/final_page_number",
),
]
else:
source_context = "/document/chunk_mark_ups/*"
mappings = [
InputFieldMappingEntry(
- name="Chunk", source="/document/chunk_mark_ups/*/chunk_mark_up"
+ name="Chunk", source="/document/chunk_mark_ups/*/final_mark_up"
),
InputFieldMappingEntry(
name="ChunkEmbedding",
@@ -265,15 +261,19 @@ def get_index_projections(self) -> SearchIndexerIndexProjection:
InputFieldMappingEntry(name="Title", source="/document/Title"),
InputFieldMappingEntry(name="SourceUri", source="/document/SourceUri"),
InputFieldMappingEntry(
- name="Sections", source="/document/chunk_mark_ups/*/chunk_sections"
+ name="Sections", source="/document/chunk_mark_ups/*/final_sections"
),
InputFieldMappingEntry(
name="ChunkFigures",
- source="/document/chunk_mark_ups/*/chunk_figures/*",
+ source="/document/chunk_mark_ups/*/final_chunk_figures/*",
),
InputFieldMappingEntry(
name="DateLastModified", source="/document/DateLastModified"
),
+ InputFieldMappingEntry(
+ name="PageNumber",
+ source="/document/chunk_mark_ups/*/final_page_number",
+ ),
]
index_projections = SearchIndexerIndexProjection(
diff --git a/image_processing/.coveragerc b/image_processing/.coveragerc
new file mode 100644
index 0000000..50cceb1
--- /dev/null
+++ b/image_processing/.coveragerc
@@ -0,0 +1,11 @@
+[run]
+omit =
+ tests/*
+ */__init__.py
+
+[report]
+omit =
+ tests/*
+ */__init__.py
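+# Note: the entries below are regular expressions matched against source lines.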
+exclude_lines =
+ if __name__ == "__main__":
diff --git a/image_processing/pyproject.toml b/image_processing/pyproject.toml
index c7b082e..6b153d1 100644
--- a/image_processing/pyproject.toml
+++ b/image_processing/pyproject.toml
@@ -43,4 +43,9 @@ dev = [
"pygments>=2.18.0",
"ruff>=0.8.1",
"python-dotenv>=1.0.1",
+ "coverage>=7.6.12",
+ "pytest>=8.3.4",
+ "pytest-asyncio>=0.25.3",
+ "pytest-cov>=6.0.0",
+ "pytest-mock>=3.14.0",
]
diff --git a/image_processing/pytest.ini b/image_processing/pytest.ini
new file mode 100644
index 0000000..84624a0
--- /dev/null
+++ b/image_processing/pytest.ini
@@ -0,0 +1,2 @@
+[pytest]
+pythonpath = src/image_processing
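+# Lets tests import the source modules directly (e.g. `import mark_up_cleaner`).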
diff --git a/image_processing/src/image_processing/__init__.py b/image_processing/src/image_processing/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/image_processing/src/image_processing/layout_analysis.py b/image_processing/src/image_processing/layout_analysis.py
index 081b76f..5a1ef4f 100644
--- a/image_processing/src/image_processing/layout_analysis.py
+++ b/image_processing/src/image_processing/layout_analysis.py
@@ -22,6 +22,7 @@
LayoutHolder,
PageWiseContentHolder,
NonPageWiseContentHolder,
+ PerPageStartingSentenceHolder,
)
@@ -340,6 +341,40 @@ def create_page_wise_content(self) -> list[LayoutHolder]:
return page_wise_contents
+ def create_per_page_starting_sentence(self) -> list[PerPageStartingSentenceHolder]:
+ """Create a list of the starting sentence of each page so we can assign the starting sentence to the page number.
+
+ Returns:
+ --------
+ list: A list of the starting sentence of each page."""
+
+ per_page_starting_sentences = []
+
+ for page in self.result.pages:
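+            # Each page's span records the offset and length of that page's
+            # text within the full document content; slice it out here.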
+ page_content = self.result.content[
+ page.spans[0]["offset"] : page.spans[0]["offset"]
+ + page.spans[0]["length"]
+ ]
+
+ # Remove any leading whitespace/newlines.
+ cleaned_content = page_content.lstrip()
+            # Split on the first newline if one exists; otherwise take everything up to the first period.
+ if "\n" in cleaned_content:
+ first_line = cleaned_content.split("\n", 1)[0]
+ elif "." in cleaned_content:
+ first_line = cleaned_content.split(".", 1)[0]
+ else:
+ first_line = cleaned_content
+
+ per_page_starting_sentences.append(
+ PerPageStartingSentenceHolder(
+ page_number=page.page_number,
+ starting_sentence=first_line.strip(),
+ )
+ )
+
+ return per_page_starting_sentences
+
async def get_document_intelligence_client(self) -> DocumentIntelligenceClient:
"""Get the Azure Document Intelligence client.
@@ -487,7 +522,12 @@ async def analyse(self):
if self.extract_figures:
await self.process_figures_from_extracted_content(text_content)
- output_record = NonPageWiseContentHolder(layout=text_content)
+ per_page_starting_sentences = self.create_per_page_starting_sentence()
+
+ output_record = NonPageWiseContentHolder(
+ layout=text_content,
+ per_page_starting_sentences=per_page_starting_sentences,
+ )
except Exception as e:
logging.error(e)
diff --git a/image_processing/src/image_processing/layout_holders.py b/image_processing/src/image_processing/layout_holders.py
index 08d1ab3..8d1535f 100644
--- a/image_processing/src/image_processing/layout_holders.py
+++ b/image_processing/src/image_processing/layout_holders.py
@@ -6,7 +6,6 @@
class FigureHolder(BaseModel):
-
"""A class to hold the figure extracted from the document."""
figure_id: str = Field(..., alias="FigureId")
@@ -48,7 +47,28 @@ class PageWiseContentHolder(BaseModel):
page_wise_layout: list[LayoutHolder]
+class PerPageStartingSentenceHolder(BaseModel):
+ """A class to hold the starting sentence of each page."""
+
+ page_number: int
+ starting_sentence: str
+
+
class NonPageWiseContentHolder(BaseModel):
"""A class to hold the non-page-wise content extracted from the document."""
layout: LayoutHolder
+ per_page_starting_sentences: list[PerPageStartingSentenceHolder] = Field(
+ default_factory=list
+ )
+
+
+class ChunkHolder(BaseModel):
+ """A class to hold the text extracted from the document after it has been chunked."""
+
+ mark_up: str
+ sections: Optional[list[str]] = Field(default_factory=list)
+ figures: Optional[list[FigureHolder]] = Field(default_factory=list)
+ starting_sentence: Optional[str] = None
+ cleaned_text: Optional[str] = None
+ page_number: Optional[int] = Field(default=None)
diff --git a/image_processing/src/image_processing/mark_up_cleaner.py b/image_processing/src/image_processing/mark_up_cleaner.py
index 30a5813..3daac77 100644
--- a/image_processing/src/image_processing/mark_up_cleaner.py
+++ b/image_processing/src/image_processing/mark_up_cleaner.py
@@ -3,7 +3,7 @@
import logging
import json
import regex as re
-from layout_holders import FigureHolder
+from layout_holders import FigureHolder, ChunkHolder
class MarkUpCleaner:
@@ -18,8 +18,8 @@ def get_sections(self, text) -> list:
list: The sections related to text
"""
# Updated regex pattern to capture markdown headers like ### Header
- combined_pattern = r"(?<=\n|^)[#]+\s*(.*?)(?=\n)"
- doc_metadata = re.findall(combined_pattern, text, re.DOTALL)
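+        # re.MULTILINE lets ^ match at each line start, so headers anywhere in
+        # the text (including the final line, via the (?=\n|$) lookahead) are captured.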
+ combined_pattern = r"^\s*[#]+\s*(.*?)(?=\n|$)"
+ doc_metadata = re.findall(combined_pattern, text, re.MULTILINE)
return self.clean_sections(doc_metadata)
def get_figure_ids(self, text: str) -> list:
@@ -61,12 +61,14 @@ def remove_markdown_tags(self, text: str, tag_patterns: dict) -> str:
for tag, pattern in tag_patterns.items():
try:
# Replace the tags using the specific pattern, keeping the content inside the tags
- if tag == "header":
+ if tag in ["header", "figure"]:
text = re.sub(
pattern, r"\2", text, flags=re.DOTALL | re.MULTILINE
)
else:
- text = re.sub(pattern, r"\1", text, flags=re.DOTALL)
+ text = re.sub(
+ pattern, r"\1", text, flags=re.DOTALL | re.MULTILINE
+ )
except re.error as e:
logging.error(f"Regex error for tag '{tag}': {e}")
except Exception as e:
@@ -74,7 +76,7 @@ def remove_markdown_tags(self, text: str, tag_patterns: dict) -> str:
return text
def clean_text_and_extract_metadata(
-        self, text: str, figures: list[FigureHolder]
-    ) -> tuple[str, str]:
+        self, chunk: ChunkHolder, figures: list[FigureHolder]
+    ) -> dict:
-        """This function performs following cleanup activities on the text, remove all unicode characters
-        remove line spacing,remove stop words, normalize characters
+        """Clean the chunk's mark up and extract metadata: markdown tags are
+        stripped, and the sections, referenced figures and page number are
+        stored on the ChunkHolder.

        Returns:
-            str: The clean text."""
+            dict: The serialised ChunkHolder with the cleaned text and metadata."""
- return_record = {}
-
try:
- logging.info(f"Input text: {text}")
- if len(text) == 0:
+ logging.info(f"Input text: {chunk.mark_up}")
+ if len(chunk.mark_up) == 0:
logging.error("Input text is empty")
raise ValueError("Input text is empty")
- return_record["chunk_mark_up"] = text
-
- figure_ids = self.get_figure_ids(text)
+ figure_ids = self.get_figure_ids(chunk.mark_up)
- return_record["chunk_sections"] = self.get_sections(text)
- return_record["chunk_figures"] = [
- figure.model_dump(by_alias=True)
- for figure in figures
- if figure.figure_id in figure_ids
+ chunk.sections = self.get_sections(chunk.mark_up)
+ chunk.figures = [
+ figure for figure in figures if figure.figure_id in figure_ids
]
- logging.info(f"Sections: {return_record['chunk_sections']}")
+ logging.info(f"Sections: {chunk.sections}")
+
+ # Check if the chunk contains only figure tags (plus whitespace).
+ figure_tag_pattern = (
+ r"(.*?)"
+ )
+ text_without_figures = re.sub(figure_tag_pattern, "", chunk.mark_up).strip()
+ if not text_without_figures and chunk.figures:
+ # When no text outside of figure tags is present, set page_number from the first figure.
+ chunk.page_number = chunk.figures[0].page_number
# Define specific patterns for each tag
tag_patterns = {
"figurecontent": r"",
- "figure": r"(.*?)",
+ "figure": r"(.*?)",
"figures": r"\(figures/\d+\)(.*?)\(figures/\d+\)",
"figcaption": r"(.*?)",
"header": r"^\s*(#{1,6})\s*(.*?)\s*$",
}
- cleaned_text = self.remove_markdown_tags(text, tag_patterns)
+ cleaned_text = self.remove_markdown_tags(chunk.mark_up, tag_patterns)
logging.info(f"Removed markdown tags: {cleaned_text}")
@@ -128,11 +133,11 @@ def clean_text_and_extract_metadata(
logging.error("Cleaned text is empty")
raise ValueError("Cleaned text is empty")
else:
- return_record["chunk_cleaned"] = cleaned_text
+ chunk.cleaned_text = cleaned_text
except Exception as e:
logging.error(f"An error occurred in clean_text_and_extract_metadata: {e}")
- return ""
- return return_record
+ raise e
+ return chunk.model_dump(by_alias=True)
async def clean(self, record: dict) -> dict:
"""Cleanup the data using standard python libraries.
@@ -157,12 +162,17 @@ async def clean(self, record: dict) -> dict:
figures = [FigureHolder(**figure) for figure in record["data"]["figures"]]
+ chunk_holder = ChunkHolder(mark_up=record["data"]["mark_up"])
+
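+        # Carry the page number through when the upstream skill supplies one.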
+ if "page_number" in record["data"]:
+ chunk_holder.page_number = record["data"]["page_number"]
+
cleaned_record["data"] = self.clean_text_and_extract_metadata(
- record["data"]["chunk"], figures
+ chunk_holder, figures
)
except Exception as e:
- logging.error("string cleanup Error: %s", e)
+ logging.error("Cleanup Error: %s", e)
return {
"recordId": record["recordId"],
"data": None,
diff --git a/image_processing/src/image_processing/requirements.txt b/image_processing/src/image_processing/requirements.txt
index b755870..519759b 100644
--- a/image_processing/src/image_processing/requirements.txt
+++ b/image_processing/src/image_processing/requirements.txt
@@ -1,6 +1,6 @@
# This file was autogenerated by uv via the following command:
# uv export --frozen --no-hashes --no-editable --no-sources --no-group dev --directory image_processing -o src/image_processing/requirements.txt
-aiohappyeyeballs==2.4.4
+aiohappyeyeballs==2.4.6
aiohttp==3.11.12
aiosignal==1.3.2
annotated-types==0.7.0
@@ -12,7 +12,7 @@ azure-ai-vision-imageanalysis==1.0.0
azure-common==1.1.28
azure-core==1.32.0
azure-functions==1.21.3
-azure-identity==1.19.0
+azure-identity==1.20.0
azure-search==1.0.0b2
azure-search-documents==11.6.0b8
azure-storage-blob==12.24.1
@@ -27,7 +27,7 @@ click==8.1.8
cloudpathlib==0.20.0
colorama==0.4.6 ; sys_platform == 'win32'
confection==0.1.5
-cryptography==44.0.0
+cryptography==44.0.1
cymem==2.0.11
distro==1.9.0
en-core-web-md @ https://github.com/explosion/spacy-models/releases/download/en_core_web_md-3.7.1/en_core_web_md-3.7.1.tar.gz
@@ -38,7 +38,7 @@ fsspec==2025.2.0
h11==0.14.0
httpcore==1.0.7
httpx==0.28.1
-huggingface-hub==0.28.1
+huggingface-hub==0.29.0
idna==3.10
isodate==0.7.2
jinja2==3.1.5
@@ -50,7 +50,7 @@ marisa-trie==1.2.1
markdown-it-py==3.0.0
markupsafe==3.0.2
mdurl==0.1.2
-model2vec==0.3.9
+model2vec==0.4.0
msal==1.31.1
msal-extensions==1.2.0
msrest==0.7.1
@@ -58,7 +58,7 @@ multidict==6.1.0
murmurhash==1.0.12
numpy==1.26.4
oauthlib==3.2.2
-openai==1.61.1
+openai==1.63.2
openpyxl==3.1.5
packaging==24.2
pandas==2.2.3
@@ -94,7 +94,7 @@ spacy-loggers==1.0.5
srsly==2.5.1
tenacity==9.0.0
thinc==8.2.5
-tiktoken==0.8.0
+tiktoken==0.9.0
tokenizers==0.21.0
tqdm==4.67.1
typer==0.15.1
diff --git a/image_processing/src/image_processing/semantic_text_chunker.py b/image_processing/src/image_processing/semantic_text_chunker.py
index 5a2c5b6..b97c667 100644
--- a/image_processing/src/image_processing/semantic_text_chunker.py
+++ b/image_processing/src/image_processing/semantic_text_chunker.py
@@ -7,6 +7,7 @@
import spacy
import numpy as np
from model2vec import StaticModel
+from layout_holders import PerPageStartingSentenceHolder, ChunkHolder
class SemanticTextChunker:
@@ -75,7 +76,7 @@ def clean_chunks_and_map(self, chunks, is_table_or_figure_map):
return cleaned_chunks, cleaned_is_table_or_figure_map
- async def chunk(self, text: str) -> list[dict]:
+ async def chunk(self, text: str) -> list[ChunkHolder]:
"""Attempts to chunk the text by:
Splitting into sentences
Grouping sentences that contain figures and tables
@@ -128,7 +129,7 @@ async def chunk(self, text: str) -> list[dict]:
for chunk in reversed_backwards_pass_chunks:
stripped_chunk = chunk.strip()
if len(stripped_chunk) > 0:
- cleaned_final_chunks.append(stripped_chunk)
+ cleaned_final_chunks.append(ChunkHolder(mark_up=stripped_chunk))
logging.info(f"Number of final chunks: {len(cleaned_final_chunks)}")
logging.info(f"Chunks: {cleaned_final_chunks}")
@@ -491,6 +492,34 @@ def sentence_similarity(self, text_1, text_2):
)
return similarity
+ def assign_page_number_to_chunks(
+ self,
+ chunks: list[ChunkHolder],
+ per_page_starting_sentences: list[PerPageStartingSentenceHolder],
+ ) -> list[ChunkHolder]:
+ """Assigns page numbers to the chunks based on the starting sentences of each page.
+
+ Args:
+ chunks (list[ChunkHolder]): The list of chunks.
+ per_page_starting_sentences (list[PerPageStartingSentenceHolder]): The list of starting sentences of each page.
+
+ Returns:
+ list[ChunkHolder]: The list of chunks with page numbers assigned."""
+ page_number = 1
+ for chunk in chunks:
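+            # Chunks arrive in document order, so resume the scan from the last
+            # matched page rather than rechecking earlier pages.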
+ for per_page_starting_sentence in per_page_starting_sentences[
+ page_number - 1 :
+ ]:
+            if per_page_starting_sentence.starting_sentence in chunk.mark_up:
+ logging.info(
+ "Assigning page number %i to chunk",
+ per_page_starting_sentence.page_number,
+ )
+ page_number = per_page_starting_sentence.page_number
+ break
+ chunk.page_number = page_number
+ return chunks
+
async def process_semantic_text_chunker(record: dict, text_chunker) -> dict:
"""Chunk the data.
@@ -514,9 +543,23 @@ async def process_semantic_text_chunker(record: dict, text_chunker) -> dict:
}
# scenarios when page by chunking is enabled
- cleaned_record["data"]["chunks"] = await text_chunker.chunk(
- record["data"]["content"]
- )
+ chunks = await text_chunker.chunk(record["data"]["content"])
+
+ if "per_page_starting_sentences" in record["data"]:
+ per_page_starting_sentences = [
+ PerPageStartingSentenceHolder(**sentence)
+ for sentence in record["data"]["per_page_starting_sentences"]
+ ]
+
+ logging.info(f"Per page starting sentences: {per_page_starting_sentences}")
+
+ chunks = text_chunker.assign_page_number_to_chunks(
+ chunks, per_page_starting_sentences
+ )
+
+ cleaned_record["data"]["chunks"] = [
+ chunk.model_dump(by_alias=True) for chunk in chunks
+ ]
except Exception as e:
logging.error("Chunking Error: %s", e)
diff --git a/image_processing/tests/image_processing/__init__.py b/image_processing/tests/image_processing/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/image_processing/tests/image_processing/test_figure_analysis.py b/image_processing/tests/image_processing/test_figure_analysis.py
new file mode 100644
index 0000000..9c1d58f
--- /dev/null
+++ b/image_processing/tests/image_processing/test_figure_analysis.py
@@ -0,0 +1,298 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+import pytest
+import base64
+import io
+from PIL import Image
+from unittest.mock import AsyncMock, MagicMock
+from tenacity import RetryError
+from openai import OpenAIError, RateLimitError
+from figure_analysis import FigureAnalysis
+from layout_holders import FigureHolder
+from httpx import Response, Request
+
+# ------------------------
+# Fixtures for Image Data
+# ------------------------
+
+
+@pytest.fixture
+def image_data_100x100():
+ """Return a base64-encoded PNG image of size 100x100."""
+ img = Image.new("RGB", (100, 100), color="red")
+ buffer = io.BytesIO()
+ img.save(buffer, format="PNG")
+ data = buffer.getvalue()
+ return base64.b64encode(data).decode("utf-8")
+
+
+@pytest.fixture
+def image_data_50x50():
+ """Return a base64-encoded PNG image of size 50x50 (small image)."""
+ img = Image.new("RGB", (50, 50), color="blue")
+ buffer = io.BytesIO()
+ img.save(buffer, format="PNG")
+ data = buffer.getvalue()
+ return base64.b64encode(data).decode("utf-8")
+
+
+# ------------------------
+# Fixtures for FigureHolder
+# ------------------------
+
+
+@pytest.fixture
+def valid_figure(image_data_100x100):
+ """
+ A valid figure with sufficient size.
+ Example: FigureHolder(figure_id='12345', description="Figure 1", uri="https://example.com/12345.png", offset=50, length=17)
+ """
+ return FigureHolder(
+ figure_id="12345",
+ description="Figure 1",
+ uri="https://example.com/12345.png",
+ offset=50,
+ length=17,
+ data=image_data_100x100,
+ )
+
+
+@pytest.fixture
+def small_figure(image_data_50x50):
+ """A figure whose image is too small (both dimensions below 75)."""
+ return FigureHolder(
+ figure_id="small1",
+ description="",
+ uri="https://example.com/small1.png",
+ offset=0,
+ length=10,
+ data=image_data_50x50,
+ )
+
+
+# ------------------------
+# Tests for get_image_size
+# ------------------------
+
+
+def test_get_image_size(valid_figure):
+ analysis = FigureAnalysis()
+ width, height = analysis.get_image_size(valid_figure)
+ assert width == 100
+ assert height == 100
+
+
+def test_get_image_size_small(small_figure):
+ analysis = FigureAnalysis()
+ width, height = analysis.get_image_size(small_figure)
+ assert width == 50
+ assert height == 50
+
+
+# ------------------------
+# Tests for understand_image_with_gptv
+# ------------------------
+
+
+@pytest.mark.asyncio
+async def test_understand_image_with_gptv_small(small_figure):
+ """
+ If both width and height are below 75, the image should be considered too small,
+ and its description set to "Irrelevant Image".
+ """
+ analysis = FigureAnalysis()
+ updated_figure = await analysis.understand_image_with_gptv(small_figure)
+ assert updated_figure.description == "Irrelevant Image"
+
+
+@pytest.mark.asyncio
+async def test_understand_image_with_gptv_success(valid_figure, monkeypatch):
+ """
+ Test the success branch of understand_image_with_gptv.
+ Patch AsyncAzureOpenAI to simulate a successful response.
+ """
+ analysis = FigureAnalysis()
+
+ # Set up required environment variables.
+ monkeypatch.setenv("OpenAI__ApiVersion", "2023-07-01-preview")
+ monkeypatch.setenv("OpenAI__MiniCompletionDeployment", "deployment123")
+ monkeypatch.setenv("OpenAI__Endpoint", "https://example.openai.azure.com")
+
+ # Create a dummy response object to mimic the client's response.
+ dummy_response = MagicMock()
+ dummy_choice = MagicMock()
+ dummy_message = MagicMock()
+ dummy_message.content = "Generated image description"
+ dummy_choice.message = dummy_message
+ dummy_response.choices = [dummy_choice]
+
+ # Create a dummy async client whose chat.completions.create returns dummy_response.
+ dummy_client = AsyncMock()
+ dummy_client.chat.completions.create.return_value = dummy_response
+
+ # Create a dummy async context manager that returns dummy_client.
+ dummy_async_context = AsyncMock()
+ dummy_async_context.__aenter__.return_value = dummy_client
+
+ # Patch AsyncAzureOpenAI so that instantiating it returns our dummy context manager.
+ monkeypatch.setattr(
+ "figure_analysis.AsyncAzureOpenAI", lambda **kwargs: dummy_async_context
+ )
+
+ # Call the function and verify the description is set from the dummy response.
+ updated_figure = await analysis.understand_image_with_gptv(valid_figure)
+ assert updated_figure.description == "Generated image description"
+
+ # Now simulate the case when the API returns an empty description.
+ dummy_message.content = ""
+ updated_figure = await analysis.understand_image_with_gptv(valid_figure)
+ assert updated_figure.description == "Irrelevant Image"
+
+
+@pytest.mark.asyncio
+async def test_understand_image_with_gptv_policy_violation(valid_figure, monkeypatch):
+ """
+ If the OpenAI API raises an error with "ResponsibleAIPolicyViolation" in its message,
+ the description should be set to "Irrelevant Image".
+ """
+ analysis = FigureAnalysis()
+ monkeypatch.setenv("OpenAI__ApiVersion", "2023-07-01-preview")
+ monkeypatch.setenv("OpenAI__MiniCompletionDeployment", "deployment123")
+ monkeypatch.setenv("OpenAI__Endpoint", "https://example.openai.azure.com")
+
+ # Define a dummy exception that mimics an OpenAI error with a ResponsibleAIPolicyViolation message.
+ class DummyOpenAIError(OpenAIError):
+ def __init__(self, message):
+ self.message = message
+
+ async def dummy_create(*args, **kwargs):
+ raise DummyOpenAIError("Error: ResponsibleAIPolicyViolation occurred")
+
+ dummy_client = AsyncMock()
+ dummy_client.chat.completions.create.side_effect = dummy_create
+ dummy_async_context = AsyncMock()
+ dummy_async_context.__aenter__.return_value = dummy_client
+ monkeypatch.setattr(
+ "figure_analysis.AsyncAzureOpenAI", lambda **kwargs: dummy_async_context
+ )
+
+ updated_figure = await analysis.understand_image_with_gptv(valid_figure)
+ assert updated_figure.description == "Irrelevant Image"
+
+
+@pytest.mark.asyncio
+async def test_understand_image_with_gptv_general_error(valid_figure, monkeypatch):
+ """
+ If the OpenAI API raises an error that does not include "ResponsibleAIPolicyViolation",
+ the error should propagate.
+ """
+ analysis = FigureAnalysis()
+ monkeypatch.setenv("OpenAI__ApiVersion", "2023-07-01-preview")
+ monkeypatch.setenv("OpenAI__MiniCompletionDeployment", "deployment123")
+ monkeypatch.setenv("OpenAI__Endpoint", "https://example.openai.azure.com")
+
+ class DummyOpenAIError(OpenAIError):
+ def __init__(self, message):
+ self.message = message
+
+ async def dummy_create(*args, **kwargs):
+ raise DummyOpenAIError("Some other error")
+
+ dummy_client = AsyncMock()
+ dummy_client.chat.completions.create.side_effect = dummy_create
+ dummy_async_context = AsyncMock()
+ dummy_async_context.__aenter__.return_value = dummy_client
+ monkeypatch.setattr(
+ "figure_analysis.AsyncAzureOpenAI", lambda **kwargs: dummy_async_context
+ )
+
+ with pytest.raises(RetryError) as e:
+ await analysis.understand_image_with_gptv(valid_figure)
+
+ root_cause = e.last_attempt.exception()
+ assert isinstance(root_cause, DummyOpenAIError)
+
+
+# ------------------------
+# Tests for analyse
+# ------------------------
+
+
+@pytest.mark.asyncio
+async def test_analyse_success(valid_figure, monkeypatch):
+ """
+ Test the successful execution of the analyse method.
+ Patch understand_image_with_gptv to return a figure with an updated description.
+ """
+ analysis = FigureAnalysis()
+ record = {"recordId": "rec1", "data": {"figure": valid_figure.model_dump()}}
+
+ async def dummy_understand(figure):
+ figure.description = "Updated Description"
+ return figure
+
+ monkeypatch.setattr(analysis, "understand_image_with_gptv", dummy_understand)
+ result = await analysis.analyse(record)
+ assert result["recordId"] == "rec1"
+ assert result["data"]["updated_figure"]["description"] == "Updated Description"
+ assert result["errors"] is None
+
+
+@pytest.mark.asyncio
+async def test_analyse_retry_rate_limit(valid_figure, monkeypatch):
+ """
+ Simulate a RetryError whose last attempt raised a RateLimitError.
+ The analyse method should return an error message indicating a rate limit error.
+ """
+ analysis = FigureAnalysis()
+ record = {"recordId": "rec2", "data": {"figure": valid_figure.model_dump()}}
+
+ # Create a mock request object
+ dummy_request = Request(
+ method="POST", url="https://api.openai.com/v1/chat/completions"
+ )
+
+ # Create a mock response object with the request set
+ dummy_response = Response(
+ status_code=429, content=b"Rate limit exceeded", request=dummy_request
+ )
+
+ # Create a RateLimitError instance
+ dummy_rate_error = RateLimitError(
+ message="Rate limit exceeded",
+ response=dummy_response,
+ body="Rate limit exceeded",
+ )
+ dummy_retry_error = RetryError(
+ last_attempt=MagicMock(exception=lambda: dummy_rate_error)
+ )
+
+ async def dummy_understand(figure):
+ raise dummy_retry_error
+
+ monkeypatch.setattr(analysis, "understand_image_with_gptv", dummy_understand)
+ result = await analysis.analyse(record)
+ assert result["recordId"] == "rec2"
+ assert result["data"] is None
+ assert result["errors"] is not None
+ assert "rate limit error" in result["errors"][0]["message"].lower()
+
+
+@pytest.mark.asyncio
+async def test_analyse_general_exception(valid_figure, monkeypatch):
+ """
+ If understand_image_with_gptv raises a general Exception,
+ analyse should catch it and return an error response.
+ """
+ analysis = FigureAnalysis()
+ record = {"recordId": "rec3", "data": {"figure": valid_figure.model_dump()}}
+
+ async def dummy_understand(figure):
+ raise Exception("General error")
+
+ monkeypatch.setattr(analysis, "understand_image_with_gptv", dummy_understand)
+ result = await analysis.analyse(record)
+ assert result["recordId"] == "rec3"
+ assert result["data"] is None
+ assert result["errors"] is not None
+ assert "check the logs for more details" in result["errors"][0]["message"].lower()
diff --git a/image_processing/tests/image_processing/test_layout_analysis.py b/image_processing/tests/image_processing/test_layout_analysis.py
new file mode 100644
index 0000000..e9de95a
--- /dev/null
+++ b/image_processing/tests/image_processing/test_layout_analysis.py
@@ -0,0 +1,493 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+import os
+import pytest
+import tempfile
+import base64
+from unittest.mock import AsyncMock
+
+from layout_analysis import (
+ process_layout_analysis,
+ LayoutAnalysis,
+)
+
+
+# --- Dummy classes to simulate ADI results and figures ---
+class DummySpan:
+ def __init__(self, offset, length):
+ self.offset = offset
+ self.length = length
+
+
+class DummyPage:
+ def __init__(self, offset, length, page_number):
+ # Simulate a page span as a dictionary.
+ self.spans = [{"offset": offset, "length": length}]
+ self.page_number = page_number
+
+
+class DummyRegion:
+ def __init__(self, page_number):
+ self.page_number = page_number
+
+
+class DummyCaption:
+ def __init__(self, content):
+ self.content = content
+
+
+class DummyPoller:
+ def __init__(self, result, operation_id):
+ self._result = result
+ self.details = {"operation_id": operation_id}
+
+ async def result(self):
+ return self._result
+
+
+class DummyDocIntelligenceClient:
+ async def __aenter__(self):
+ return self
+
+ async def __aexit__(self, exc_type, exc, tb):
+ pass
+
+ async def begin_analyze_document(self, **kwargs):
+ # Create a dummy page spanning the first 5 characters.
+ dummy_page = DummyPage(0, 5, 1)
+ dummy_result = DummyResult("HelloWorld", pages=[dummy_page], figures=[])
+ return DummyPoller(dummy_result, "dummy_op")
+
+
+class DummyFigure:
+ def __init__(self, id, offset, length, page_number, caption_content):
+ self.id = id # note: process_figures_from_extracted_content checks "if figure.id is None"
+ self.bounding_regions = [DummyRegion(page_number)]
+ self.caption = DummyCaption(caption_content)
+ self.spans = [DummySpan(offset, length)]
+
+
+class DummyResult:
+ def __init__(self, content, pages, figures, model_id="model123"):
+ self.content = content
+ self.pages = pages
+ self.figures = figures
+ self.model_id = model_id
+
+
+# --- Dummy StorageAccountHelper for testing purposes ---
+class DummyStorageAccountHelper:
+ @property
+ def account_url(self):
+ return "http://dummy.storage"
+
+ async def upload_blob(self, container, blob, data, content_type):
+ # Simulate a successful upload returning a URL.
+ return f"http://dummy.url/{blob}"
+
+ async def download_blob_to_temp_dir(self, source, container, target_file_name):
+ # Write dummy content to a temp file and return its path along with empty metadata.
+ temp_file_path = os.path.join(tempfile.gettempdir(), target_file_name)
+ with open(temp_file_path, "wb") as f:
+ f.write(b"dummy file content")
+ return temp_file_path, {}
+
+ async def add_metadata_to_blob(self, source, container, metadata, upsert=False):
+ # Dummy method; do nothing.
+ return
+
+
+# --- Fixtures and environment setup ---
+@pytest.fixture(autouse=True)
+def set_env_vars(monkeypatch):
+ monkeypatch.setenv("StorageAccount__Name", "dummyaccount")
+ monkeypatch.setenv(
+ "AIService__DocumentIntelligence__Endpoint", "http://dummy.ai.endpoint"
+ )
+
+
+@pytest.fixture
+def dummy_storage_helper():
+ return DummyStorageAccountHelper()
+
+
+# --- Tests for LayoutAnalysis and process_layout_analysis ---
+
+
+def test_extract_file_info():
+ # Given a typical blob URL, extract_file_info should correctly set properties.
+ source = "https://dummyaccount.blob.core.windows.net/container/path/to/file.pdf"
+ la = LayoutAnalysis(record_id=1, source=source)
+ la.extract_file_info()
+ assert la.blob == "path/to/file.pdf"
+ assert la.container == "container"
+ assert la.images_container == "container-figures"
+ assert la.file_extension == "pdf"
+ assert la.target_file_name == "1.pdf"
+
+
+# Test non-page-wise analysis without figures.
+@pytest.mark.asyncio
+async def test_analyse_non_page_wise_no_figures(monkeypatch, dummy_storage_helper):
+ source = "https://dummyaccount.blob.core.windows.net/container/path/to/file.txt"
+ la = LayoutAnalysis(
+ page_wise=False, extract_figures=True, record_id=123, source=source
+ )
+ la.extract_file_info()
+ # Patch get_storage_account_helper to return our dummy helper.
+ monkeypatch.setattr(
+ la, "get_storage_account_helper", AsyncMock(return_value=dummy_storage_helper)
+ )
+ # Patch download_blob_to_temp_dir to simulate a successful download.
+ monkeypatch.setattr(
+ dummy_storage_helper,
+ "download_blob_to_temp_dir",
+ AsyncMock(return_value=("/tmp/dummy.txt", {})),
+ )
+ # Patch analyse_document to simulate a successful ADI analysis.
+ dummy_result = DummyResult(
+ content="Full document content", pages=[DummyPage(0, 21, 1)], figures=[]
+ )
+
+ async def dummy_analyse_document(file_path):
+ la.result = dummy_result
+ la.operation_id = "op123"
+
+ monkeypatch.setattr(la, "analyse_document", dummy_analyse_document)
+ # Patch process_figures_from_extracted_content to do nothing (since there are no figures).
+ monkeypatch.setattr(la, "process_figures_from_extracted_content", AsyncMock())
+ result = await la.analyse()
+ assert result["recordId"] == 123
+ data = result["data"]
+ # In non-page-wise mode, the output record is a NonPageWiseContentHolder
+ assert "layout" in data
+ layout = data["layout"]
+ assert layout["content"] == "Full document content"
+ # No figures were processed.
+ assert layout.get("figures", []) == []
+ assert result["errors"] is None
+
+
+# Test page-wise analysis without figures.
+@pytest.mark.asyncio
+async def test_analyse_page_wise_no_figures(monkeypatch, dummy_storage_helper):
+ source = "https://dummyaccount.blob.core.windows.net/container/path/to/file.txt"
+ la = LayoutAnalysis(
+ page_wise=True, extract_figures=True, record_id=456, source=source
+ )
+ la.extract_file_info()
+ monkeypatch.setattr(
+ la, "get_storage_account_helper", AsyncMock(return_value=dummy_storage_helper)
+ )
+ monkeypatch.setattr(
+ dummy_storage_helper,
+ "download_blob_to_temp_dir",
+ AsyncMock(return_value=("/tmp/dummy.txt", {})),
+ )
+ # Create a dummy result with one page and no figures.
+ dummy_page = DummyPage(0, 12, 1)
+ dummy_result = DummyResult(content="Page content", pages=[dummy_page], figures=[])
+
+ async def dummy_analyse_document(file_path):
+ la.result = dummy_result
+ la.operation_id = "op456"
+
+ monkeypatch.setattr(la, "analyse_document", dummy_analyse_document)
+ result = await la.analyse()
+ assert result["recordId"] == 456
+ data = result["data"]
+ # In page-wise mode, the output should have a "page_wise_layout" key.
+ assert "page_wise_layout" in data
+ layouts = data["page_wise_layout"]
+ assert len(layouts) == 1
+ layout = layouts[0]
+ # The content is extracted from dummy_result.content using the page span.
+ expected_content = dummy_result.content[0:12]
+ assert layout["content"] == expected_content
+ assert layout["page_number"] == 1
+ assert result["errors"] is None
+
+
+# Test page-wise analysis with figures (covering figure download and upload).
+@pytest.mark.asyncio
+async def test_analyse_page_wise_with_figures(monkeypatch, dummy_storage_helper):
+ source = "https://dummyaccount.blob.core.windows.net/container/path/to/file.txt"
+ la = LayoutAnalysis(
+ page_wise=True, extract_figures=True, record_id=789, source=source
+ )
+ la.extract_file_info()
+ monkeypatch.setattr(
+ la, "get_storage_account_helper", AsyncMock(return_value=dummy_storage_helper)
+ )
+ monkeypatch.setattr(
+ dummy_storage_helper,
+ "download_blob_to_temp_dir",
+ AsyncMock(return_value=("/tmp/dummy.txt", {})),
+ )
+ # Create a dummy page and a dummy figure.
+ dummy_page = DummyPage(0, 12, 1)
+ dummy_figure = DummyFigure(
+ "fig1", offset=5, length=5, page_number=1, caption_content="Caption text"
+ )
+ dummy_result = DummyResult(
+ content="Page content", pages=[dummy_page], figures=[dummy_figure]
+ )
+
+ async def dummy_analyse_document(file_path):
+ la.result = dummy_result
+ la.operation_id = "op789"
+
+ monkeypatch.setattr(la, "analyse_document", dummy_analyse_document)
+ # Patch download_figure_image to simulate downloading image bytes.
+ monkeypatch.setattr(
+ la, "download_figure_image", AsyncMock(return_value=b"fake_image")
+ )
+ # Patch upload_blob to simulate a successful upload.
+ monkeypatch.setattr(
+ dummy_storage_helper,
+ "upload_blob",
+ AsyncMock(return_value="http://dummy.url/fig1.png"),
+ )
+ result = await la.analyse()
+ assert result["recordId"] == 789
+ data = result["data"]
+ assert "page_wise_layout" in data
+ layouts = data["page_wise_layout"]
+ # The page layout should have a figures list containing our processed figure.
+ assert len(layouts) == 1
+ layout = layouts[0]
+ assert "figures" in layout
+ figures_list = layout["figures"]
+ assert len(figures_list) == 1
+ figure_data = figures_list[0]
+ assert figure_data["figure_id"] == "fig1"
+ # The data field should contain the base64-encoded image.
+ expected_b64 = base64.b64encode(b"fake_image").decode("utf-8")
+ assert figure_data["data"] == expected_b64
+    # Verify that the caption is set as expected.
+ assert figure_data["caption"] == "Caption text"
+ assert result["errors"] is None
+
+
+# Test failure during blob download.
+@pytest.mark.asyncio
+async def test_analyse_download_blob_failure(monkeypatch, dummy_storage_helper):
+ source = "https://dummyaccount.blob.core.windows.net/container/path/to/file.txt"
+ la = LayoutAnalysis(
+ page_wise=False, extract_figures=True, record_id=321, source=source
+ )
+ la.extract_file_info()
+ monkeypatch.setattr(
+ la, "get_storage_account_helper", AsyncMock(return_value=dummy_storage_helper)
+ )
+ # Simulate a failure in download_blob_to_temp_dir.
+ monkeypatch.setattr(
+ dummy_storage_helper,
+ "download_blob_to_temp_dir",
+ AsyncMock(side_effect=Exception("Download error")),
+ )
+ result = await la.analyse()
+ assert result["recordId"] == 321
+ assert result["data"] is None
+ assert result["errors"] is not None
+ assert "Failed to download the blob" in result["errors"][0]["message"]
+
+
+# Test failure during analyse_document (simulate ADI failure) and ensure metadata is updated.
+@pytest.mark.asyncio
+async def test_analyse_document_failure(monkeypatch, dummy_storage_helper):
+ source = "https://dummyaccount.blob.core.windows.net/container/path/to/file.txt"
+ la = LayoutAnalysis(
+ page_wise=False, extract_figures=True, record_id=654, source=source
+ )
+ la.extract_file_info()
+ monkeypatch.setattr(
+ la, "get_storage_account_helper", AsyncMock(return_value=dummy_storage_helper)
+ )
+ monkeypatch.setattr(
+ dummy_storage_helper,
+ "download_blob_to_temp_dir",
+ AsyncMock(return_value=("/tmp/dummy.txt", {})),
+ )
+
+ # Simulate analyse_document throwing an exception.
+ async def dummy_analyse_document_failure(file_path):
+ raise Exception("Analyse document error")
+
+ monkeypatch.setattr(la, "analyse_document", dummy_analyse_document_failure)
+ # Track whether add_metadata_to_blob is called.
+ metadata_called = False
+
+ async def dummy_add_metadata(source, container, metadata, upsert=False):
+ nonlocal metadata_called
+ metadata_called = True
+
+ monkeypatch.setattr(
+ dummy_storage_helper, "add_metadata_to_blob", dummy_add_metadata
+ )
+ result = await la.analyse()
+ assert result["recordId"] == 654
+ assert result["data"] is None
+ assert result["errors"] is not None
+ assert (
+ "Failed to analyze the document with Azure Document Intelligence"
+ in result["errors"][0]["message"]
+ )
+ assert metadata_called is True
+
+
+# Test failure during processing of extracted content (e.g. page-wise content creation).
+@pytest.mark.asyncio
+async def test_analyse_processing_content_failure(monkeypatch, dummy_storage_helper):
+ source = "https://dummyaccount.blob.core.windows.net/container/path/to/file.txt"
+ la = LayoutAnalysis(
+ page_wise=True, extract_figures=True, record_id=987, source=source
+ )
+ la.extract_file_info()
+ monkeypatch.setattr(
+ la, "get_storage_account_helper", AsyncMock(return_value=dummy_storage_helper)
+ )
+ monkeypatch.setattr(
+ dummy_storage_helper,
+ "download_blob_to_temp_dir",
+ AsyncMock(return_value=("/tmp/dummy.txt", {})),
+ )
+ # Simulate a successful analyse_document.
+ dummy_page = DummyPage(0, 12, 1)
+ dummy_result = DummyResult(content="Page content", pages=[dummy_page], figures=[])
+
+ async def dummy_analyse_document(file_path):
+ la.result = dummy_result
+ la.operation_id = "op987"
+
+ monkeypatch.setattr(la, "analyse_document", dummy_analyse_document)
+
+ # Patch create_page_wise_content to raise an exception.
+ def raise_exception():
+ raise Exception("Processing error")
+
+ monkeypatch.setattr(la, "create_page_wise_content", raise_exception)
+ result = await la.analyse()
+ assert result["recordId"] == 987
+ assert result["data"] is None
+ assert result["errors"] is not None
+ assert "Failed to process the extracted content" in result["errors"][0]["message"]
+
+
+# Test process_layout_analysis when 'source' is missing (KeyError branch).
+@pytest.mark.asyncio
+async def test_process_layout_analysis_missing_source():
+ record = {"recordId": "111", "data": {}} # missing 'source' key
+ result = await process_layout_analysis(record)
+ assert result["recordId"] == "111"
+ assert result["data"] is None
+ assert result["errors"] is not None
+ assert "Pass a valid source" in result["errors"][0]["message"]
+
+
+@pytest.mark.asyncio
+async def test_analyse_document_success(monkeypatch, tmp_path):
+ # Create a temporary file with dummy content.
+ tmp_file = tmp_path / "dummy.txt"
+ tmp_file.write_bytes(b"dummy content")
+
+ la = LayoutAnalysis(
+ record_id=999,
+ source="https://dummyaccount.blob.core.windows.net/container/path/to/dummy.txt",
+ )
+
+ # Use an async function to return our dummy Document Intelligence client.
+ async def dummy_get_doc_intelligence_client():
+ return DummyDocIntelligenceClient()
+
+ monkeypatch.setattr(
+ la, "get_document_intelligence_client", dummy_get_doc_intelligence_client
+ )
+
+ await la.analyse_document(str(tmp_file))
+
+ assert la.result is not None
+ assert la.operation_id == "dummy_op"
+ # Check that the dummy result contains the expected content.
+ assert la.result.content == "HelloWorld"
+
+
+def test_create_page_wise_content():
+ # Test create_page_wise_content using a dummy result with one page.
+ la = LayoutAnalysis(record_id=100, source="dummy")
+
+ # Create a dummy result with content "HelloWorld"
+ # and a page with a span from index 0 with length 5.
+ class DummyResultContent:
+ pass
+
+ dummy_result = DummyResultContent()
+ dummy_result.content = "HelloWorld"
+ dummy_result.pages = [DummyPage(0, 5, 1)]
+ la.result = dummy_result
+
+ layouts = la.create_page_wise_content()
+ assert isinstance(layouts, list)
+ assert len(layouts) == 1
+ layout = layouts[0]
+ # The page content should be the substring "Hello"
+ assert layout.content == "Hello"
+ assert layout.page_number == 1
+ assert layout.page_offsets == 0
+
+
+def test_create_per_page_starting_sentence():
+ # Create a LayoutAnalysis instance.
+ la = LayoutAnalysis(record_id=200, source="dummy")
+
+ # Create a dummy result with content and pages.
+ # For this test, the first page's content slice will be "HelloWorld" (from index 0 with length 10),
+ # so the starting sentence extracted should be "HelloWorld".
+ class DummyResultContent:
+ pass
+
+ dummy_result = DummyResultContent()
+ dummy_result.content = "HelloWorld. This is a test sentence."
+ # DummyPage creates a page with spans as a list of dictionaries.
+ dummy_result.pages = [DummyPage(0, 10, 1)]
+ la.result = dummy_result
+
+ sentences = la.create_per_page_starting_sentence()
+ assert len(sentences) == 1
+ sentence = sentences[0]
+ assert sentence.page_number == 1
+ assert sentence.starting_sentence == "HelloWorld"
+
+
+def test_create_per_page_starting_sentence_multiple_pages():
+ # Create a LayoutAnalysis instance.
+ la = LayoutAnalysis(record_id=300, source="dummy")
+
+ # Create a dummy result with content spanning two pages.
+ # Use DummyPage to simulate pages; DummyPage expects "spans" as a list of dicts.
+ class DummyResultContent:
+ pass
+
+ dummy_result = DummyResultContent()
+ # Define content as two parts:
+    # Page 1: offset 0, length 9 gives "Page one." (starting sentence "Page one")
+    # Page 2: offset 9, length 78 gives the remainder (starting sentence "Page two text and more content")
+ dummy_result.content = "Page one.Page two text and more content. This is more random content that is on page 2."
+ dummy_result.pages = [
+ DummyPage(0, 9, 1), # "Page one." (9 characters: indices 0-8)
+        DummyPage(9, 78, 2),  # Remainder of the content (78 characters: indices 9-86)
+ ]
+ la.result = dummy_result
+
+ # Call create_per_page_starting_sentence and check results.
+ sentences = la.create_per_page_starting_sentence()
+ assert len(sentences) == 2
+
+ # For page 1, the substring is "Page one." -> split on "." gives "Page one"
+ assert sentences[0].page_number == 1
+ assert sentences[0].starting_sentence == "Page one"
+
+    # For page 2, the substring has no newline -> split on "." gives "Page two text and more content"
+ assert sentences[1].page_number == 2
+ # We strip potential leading/trailing spaces for validation.
+ assert sentences[1].starting_sentence.strip() == "Page two text and more content"
diff --git a/image_processing/tests/image_processing/test_layout_and_figure_merger.py b/image_processing/tests/image_processing/test_layout_and_figure_merger.py
new file mode 100644
index 0000000..3deb271
--- /dev/null
+++ b/image_processing/tests/image_processing/test_layout_and_figure_merger.py
@@ -0,0 +1,114 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+import pytest
+from layout_holders import LayoutHolder, FigureHolder
+from layout_and_figure_merger import LayoutAndFigureMerger
+
+
+@pytest.fixture
+def layout_holder():
+ return LayoutHolder(
+ content="This is a sample layout with a figure placeholder. This is a sentence after."
+ )
+
+
+@pytest.fixture
+def figure_holder():
+ return FigureHolder(
+ figure_id="12345",
+ description="Figure 1",
+ uri="https://example.com/12345.png",
+ offset=50,
+ length=17,
+ )
+
+
+@pytest.fixture
+def merger():
+ return LayoutAndFigureMerger()
+
+
+def test_insert_figure_description(merger, layout_holder, figure_holder):
+ updated_layout, inserted_length = merger.insert_figure_description(
+ layout_holder, figure_holder
+ )
+ assert "Figure 1" in updated_layout.content
+ assert (
+ inserted_length
+ == len("Figure 1") - figure_holder.length
+ )
+ assert (
+ updated_layout.content
+ == "This is a sample layout with a figure placeholder.Figure 1 This is a sentence after."
+ )
+
+
+def test_insert_figure_invalid_offset(merger, layout_holder):
+ invalid_figure = FigureHolder(
+ figure_id="12345",
+ offset=100,
+ length=5,
+ description="Invalid figure",
+ uri="https://example.com/12345.png",
+ )
+ with pytest.raises(ValueError, match="Figure offset is out of bounds"):
+ merger.insert_figure_description(layout_holder, invalid_figure)
+
+
+@pytest.mark.asyncio
+async def test_merge_figures_into_layout(merger, layout_holder, figure_holder):
+ figures = [figure_holder]
+ updated_layout = await merger.merge_figures_into_layout(layout_holder, figures)
+ assert "Figure 1" in updated_layout.content
+ assert (
+ updated_layout.content
+ == "This is a sample layout with a figure placeholder.Figure 1 This is a sentence after."
+ )
+
+
+@pytest.mark.asyncio
+async def test_merge_removes_irrelevant_figures(merger):
+ layout_holder = LayoutHolder(
+ content="Before 'Irrelevant Image' After"
+ )
+ updated_layout = await merger.merge_figures_into_layout(layout_holder, [])
+ assert "Irrelevant Image" not in updated_layout.content
+ assert "Before After" in updated_layout.content
+
+
+@pytest.mark.asyncio
+async def test_merge_removes_empty_figures(merger):
+ layout_holder = LayoutHolder(content="Before After")
+ updated_layout = await merger.merge_figures_into_layout(layout_holder, [])
+ assert "" not in updated_layout.content
+ assert "Before After" in updated_layout.content
+
+
+@pytest.mark.asyncio
+async def test_merge_removes_html_comments(merger):
+ layout_holder = LayoutHolder(content="Before After")
+ updated_layout = await merger.merge_figures_into_layout(layout_holder, [])
+ assert "" not in updated_layout.content
+ assert "Before After" in updated_layout.content
+
+
+@pytest.mark.asyncio
+async def test_merge_handles_exception(merger):
+ record = {
+ "recordId": "1",
+ "data": {
+ "layout": {"content": "Sample"},
+ "figures": [
+ {
+ "figure_id": "12345",
+ "offset": 1000,
+ "length": 5,
+ "description": "Invalid",
+ "uri": "https://example.com/12345.png",
+ }
+ ],
+ },
+ }
+ response = await merger.merge(record)
+ assert response["data"] is None
+ assert response["errors"] is not None
diff --git a/image_processing/tests/image_processing/test_layout_holders.py b/image_processing/tests/image_processing/test_layout_holders.py
new file mode 100644
index 0000000..3d2d1c4
--- /dev/null
+++ b/image_processing/tests/image_processing/test_layout_holders.py
@@ -0,0 +1,107 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+import pytest
+from pydantic import ValidationError
+from layout_holders import (
+ FigureHolder,
+ LayoutHolder,
+ PageWiseContentHolder,
+ NonPageWiseContentHolder,
+ ChunkHolder,
+ PerPageStartingSentenceHolder,
+)
+
+
+def test_figure_holder_creation():
+ figure = FigureHolder(
+ FigureId="fig1",
+ offset=10,
+ length=5,
+ Uri="http://example.com/fig1.png",
+ Description="Sample figure",
+ )
+
+ assert figure.figure_id == "fig1"
+ assert figure.offset == 10
+ assert figure.length == 5
+ assert figure.uri == "http://example.com/fig1.png"
+ assert figure.description == "Sample figure"
+    assert figure.markdown == "<figure FigureId='fig1'>Sample figure</figure>"
+
+
+def test_figure_holder_missing_required_fields():
+ with pytest.raises(ValidationError):
+ FigureHolder(offset=10, length=5, Uri="http://example.com/fig1.png")
+
+
+def test_layout_holder_creation():
+ layout = LayoutHolder(content="Sample content")
+ assert layout.content == "Sample content"
+ assert layout.page_number is None
+ assert layout.page_offsets == 0
+ assert layout.figures == []
+
+
+def test_layout_holder_with_figures():
+ figure = FigureHolder(
+ FigureId="fig1",
+ offset=10,
+ length=5,
+ Uri="http://example.com/fig1.png",
+ Description="Sample figure",
+ )
+ layout = LayoutHolder(content="Sample content", figures=[figure])
+ assert len(layout.figures) == 1
+ assert layout.figures[0].figure_id == "fig1"
+
+
+def test_page_wise_content_holder():
+ layout1 = LayoutHolder(content="Page 1")
+ layout2 = LayoutHolder(content="Page 2")
+ page_holder = PageWiseContentHolder(page_wise_layout=[layout1, layout2])
+ assert len(page_holder.page_wise_layout) == 2
+ assert page_holder.page_wise_layout[0].content == "Page 1"
+
+
+def test_non_page_wise_content_holder():
+ layout = LayoutHolder(content="Full document")
+ non_page_holder = NonPageWiseContentHolder(layout=layout)
+ assert non_page_holder.layout.content == "Full document"
+
+
+def test_chunk_holder_creation():
+ chunk = ChunkHolder(
+ mark_up="Sample markup",
+ sections=["Section1", "Section2"],
+ figures=[],
+ starting_sentence="First sentence",
+ cleaned_text="Cleaned text content",
+ page_number=1,
+ )
+ assert chunk.mark_up == "Sample markup"
+ assert chunk.sections == ["Section1", "Section2"]
+ assert chunk.starting_sentence == "First sentence"
+ assert chunk.cleaned_text == "Cleaned text content"
+ assert chunk.page_number == 1
+
+
+def test_per_page_starting_sentence_holder_creation():
+ sentence = PerPageStartingSentenceHolder(
+ page_number=1, starting_sentence="This is the starting sentence."
+ )
+ assert sentence.page_number == 1
+ assert sentence.starting_sentence == "This is the starting sentence."
+
+
+def test_non_page_wise_content_holder_with_sentences():
+ layout = LayoutHolder(content="Full document")
+ sentences = [
+ PerPageStartingSentenceHolder(page_number=1, starting_sentence="Start 1"),
+ PerPageStartingSentenceHolder(page_number=2, starting_sentence="Start 2"),
+ ]
+ non_page_holder = NonPageWiseContentHolder(
+ layout=layout, per_page_starting_sentences=sentences
+ )
+ assert non_page_holder.layout.content == "Full document"
+ assert len(non_page_holder.per_page_starting_sentences) == 2
+ assert non_page_holder.per_page_starting_sentences[0].starting_sentence == "Start 1"
diff --git a/image_processing/tests/image_processing/test_mark_up_cleaner.py b/image_processing/tests/image_processing/test_mark_up_cleaner.py
new file mode 100644
index 0000000..82497dc
--- /dev/null
+++ b/image_processing/tests/image_processing/test_mark_up_cleaner.py
@@ -0,0 +1,249 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+import pytest
+from mark_up_cleaner import MarkUpCleaner
+from layout_holders import FigureHolder, ChunkHolder
+
+
+# Fixtures
+@pytest.fixture
+def cleaner():
+ return MarkUpCleaner()
+
+
+@pytest.fixture
+def sample_text():
+ return """
+ # Header 1
+ Some text.
+ ## Header 2
+ More text.
+ <figure FigureId='12345'>Figure description</figure>

+ """
+
+
+@pytest.fixture
+def figures():
+ # We'll use the object-based representation for figures.
+ return [
+ FigureHolder(
+ FigureId="fig1",
+ offset=10,
+ length=5,
+ Uri="http://example.com/fig1.png",
+ Description="Sample figure",
+ ),
+ # Only this figure is kept downstream: its id ("12345") is the one
+ # get_figure_ids extracts from sample_text.
+ FigureHolder(
+ FigureId="12345",
+ offset=0,
+ length=8,
+ Uri="https://example.com/12345.png",
+ Description="Figure 1",
+ ),
+ ]
+
+
+# Test get_sections: extracts the markdown headers, then cleans them via clean_sections internally.
+def test_get_sections(cleaner, sample_text):
+ sections = cleaner.get_sections(sample_text)
+ # Expecting headers extracted and cleaned.
+ assert sections == ["Header 1", "Header 2"]
+
+
+# Test get_figure_ids: using regex extraction.
+def test_get_figure_ids(cleaner, sample_text):
+ figure_ids = cleaner.get_figure_ids(sample_text)
+ assert figure_ids == ["12345"]
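+
+# A sketch of the extraction this test implies (an assumption, not the actual
+# implementation in mark_up_cleaner.py):
+#
+#     import re
+#
+#     def get_figure_ids(text: str) -> list:
+#         return re.findall(r"FigureId='([^']+)'", text)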
+
+
+# Test clean_sections: Remove leading hashes and extra chars.
+def test_clean_sections(cleaner):
+ sections = ["### Section 1", "## Section 2"]
+ cleaned = cleaner.clean_sections(sections)
+ assert cleaned == ["Section 1", "Section 2"]
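+
+# clean_sections presumably strips leading hashes and surrounding whitespace;
+# a hedged sketch that would satisfy this test (assumes `import re`):
+#
+#     cleaned = [re.sub(r"^#+\s*", "", s).strip() for s in sections]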
+
+
+# Test remove_markdown_tags: Ensure tags are removed/replaced.
+def test_remove_markdown_tags(cleaner):
+ text = """
+ Some figure
+
+ # Header
+ Random sentence
+ """
+ tag_patterns = {
+ "figurecontent": r"",
+ "figure": r"(.*?)",
+ }
+ cleaned_text = cleaner.remove_markdown_tags(text, tag_patterns)
+ # Check that the inner contents are retained but tags removed.
+ assert "Some figure" in cleaned_text
+ assert "Some content" in cleaned_text
+ assert "" not in cleaned_text
+ assert "Some figure" not in cleaned_text
+
+
+# Test clean_text_and_extract_metadata: Pass a ChunkHolder instance
+def test_clean_text_and_extract_metadata(cleaner, sample_text, figures):
+ # Create a ChunkHolder from the sample text.
+ chunk = ChunkHolder(mark_up=sample_text)
+ result = cleaner.clean_text_and_extract_metadata(chunk, figures)
+ # result is a dict returned from model_dump (by alias)
+ assert isinstance(result, dict)
+ # The input text is stored under 'mark_up'
+ assert result["mark_up"] == sample_text
+ # get_sections should extract the headers.
+ assert result["sections"] == ["Header 1", "Header 2"]
+ # get_figure_ids returns ["12345"] so only the matching figure is kept.
+ assert len(result["figures"]) == 1
+ # FigureHolder uses alias "FigureId" for its id.
+ assert result["figures"][0]["FigureId"] == "12345"
+ # The cleaned text should have removed markup such as FigureId info.
+ assert "FigureId='12345'" not in result["cleaned_text"]
+
+
+# Async test for clean: the record's "data" holds the mark_up text and its figure list directly.
+@pytest.mark.asyncio
+async def test_clean(cleaner, sample_text):
+ record = {
+ "recordId": "1",
+ "data": {
+ "mark_up": sample_text,
+ "figures": [
+ {
+ "figure_id": "12345",
+ "uri": "https://example.com/12345.png",
+ "description": "Figure 1",
+ "offset": 0,
+ "length": 8,
+ },
+ {
+ "figure_id": "123456789",
+ "uri": "https://example.com/123456789.png",
+ "description": "Figure 2",
+ "offset": 10,
+ "length": 8,
+ },
+ ],
+ },
+ }
+ result = await cleaner.clean(record)
+ assert isinstance(result, dict)
+ assert result["recordId"] == "1"
+ # Ensure data was successfully cleaned
+ assert result["data"] is not None
+ assert result["data"]["cleaned_text"]
+ # Check that the expected keys are in the cleaned data.
+ assert "mark_up" in result["data"]
+ assert "sections" in result["data"]
+ assert "figures" in result["data"]
+ # Only one figure must match because get_figure_ids extracted "12345"
+ assert len(result["data"]["figures"]) == 1
+ assert result["data"]["figures"][0]["FigureId"] == "12345"
+
+
+# Test get_sections with empty text returns empty list.
+def test_get_sections_empty_text(cleaner):
+ sections = cleaner.get_sections("")
+ assert sections == []
+
+
+# Test get_figure_ids with no figure tags.
+def test_get_figure_ids_no_figures(cleaner):
+ text = "This text does not include any figures."
+ assert cleaner.get_figure_ids(text) == []
+
+
+# Test remove_markdown_tags with unknown tag patterns (should remain unchanged).
+def test_remove_markdown_tags_unknown_tag(cleaner):
+ text = "This is a basic text without markdown."
+ tag_patterns = {"nonexistent": r"(pattern)"}
+ result = cleaner.remove_markdown_tags(text, tag_patterns)
+ assert result == text
+
+
+# Test clean_text_and_extract_metadata with empty text: Should raise ValueError.
+def test_clean_text_and_extract_metadata_empty_text(cleaner, figures):
+ chunk = ChunkHolder(mark_up="")
+ with pytest.raises(ValueError):
+ cleaner.clean_text_and_extract_metadata(chunk, figures)
+
+
+# Async test: missing "chunk" key in record -> error branch of clean().
+@pytest.mark.asyncio
+async def test_clean_missing_chunk(cleaner):
+ record = {
+ "recordId": "3",
+ "data": {"figures": []},
+ }
+ result = await cleaner.clean(record)
+ assert result["recordId"] == "3"
+ assert result["data"] is None
+ assert result["errors"] is not None
+ assert "Failed to cleanup data" in result["errors"][0]["message"]
+
+
+# Async test: invalid figure structure causing an exception in clean()
+@pytest.mark.asyncio
+async def test_clean_with_invalid_figures_structure(cleaner):
+ record = {
+ "recordId": "4",
+ "data": {
+ "chunk": {"mark_up": "Some text with # Header"},
+ # Figures missing required keys for FigureHolder.
+ "figures": [{"invalid_key": "no_fig_id"}],
+ },
+ }
+ result = await cleaner.clean(record)
+ assert result["recordId"] == "4"
+ assert result["data"] is None
+ assert result["errors"] is not None
+
+
+def test_clean_only_figures_sets_page_number(cleaner):
+ # Input contains only a figure tag.
+ text = "I am a random description"
+ chunk = ChunkHolder(mark_up=text, page_number=1)
+ figs = [
+ FigureHolder(
+ FigureId="12345",
+ offset=0,
+ length=10,
+ Uri="http://example.com/12345.png",
+ Description="Figure 1",
+ page_number=2, # This page number should be picked up.
+ ),
+ FigureHolder(
+ FigureId="67890",
+ offset=20,
+ length=10,
+ Uri="http://example.com/67890.png",
+ Description="Figure 2",
+ page_number=4,
+ ),
+ ]
+ result = cleaner.clean_text_and_extract_metadata(chunk, figs)
+ # Because no text outside the figure tag is present, sections should be empty,
+ # and page_number should be set from the first matching figure.
+ assert result.get("sections") == []
+ assert result["page_number"] == 2
+
+
+def test_clean_text_with_mixed_content_leaves_page_number_unset(cleaner):
+ # Input contains text outside the figure tag. Even though a figure is
+ # present, the surrounding text means the chunk could span multiple pages,
+ # so page_number must not be auto-set from the figure.
+ text = "More text before the figure. <figure FigureId='12345'>Figure content</figure>"
+ chunk = ChunkHolder(mark_up=text, page_number=4)
+ figs = [
+ FigureHolder(
+ FigureId="12345",
+ offset=0,
+ length=10,
+ Uri="http://example.com/12345.png",
+ Description="Figure 1",
+ page_number=5, # This should be ignored since text exists.
+ )
+ ]
+ result = cleaner.clean_text_and_extract_metadata(chunk, figs)
+ assert result.get("page_number") == 4
diff --git a/image_processing/tests/image_processing/test_semantic_text_chunker.py b/image_processing/tests/image_processing/test_semantic_text_chunker.py
new file mode 100644
index 0000000..59e8364
--- /dev/null
+++ b/image_processing/tests/image_processing/test_semantic_text_chunker.py
@@ -0,0 +1,355 @@
+import pytest
+from unittest.mock import AsyncMock, MagicMock
+
+from semantic_text_chunker import (
+ process_semantic_text_chunker,
+ SemanticTextChunker,
+)
+
+# --- Dummy Classes for Process-Level Tests ---
+
+
+class DummyChunkHolder:
+ def __init__(self, mark_up, page_number=None):
+ self.mark_up = mark_up
+ self.page_number = page_number
+
+ def model_dump(self, by_alias=False):
+ return {"mark_up": self.mark_up, "page_number": self.page_number}
+
+
+class DummyPerPageStartingSentenceHolder:
+ def __init__(self, starting_sentence, page_number):
+ self.starting_sentence = starting_sentence
+ self.page_number = page_number
+
+
+# --- Process-Level Tests (Using Dummy Chunker) ---
+
+
+@pytest.mark.asyncio
+async def test_process_semantic_text_chunker_success_without_page():
+ """Test a successful chunking when no per-page starting sentences are provided."""
+ record = {"recordId": "1", "data": {"content": "Some content to be chunked."}}
+
+ dummy_chunk = DummyChunkHolder("chunk1")
+ dummy_text_chunker = MagicMock()
+ dummy_text_chunker.chunk = AsyncMock(return_value=[dummy_chunk])
+ dummy_text_chunker.assign_page_number_to_chunks = MagicMock()
+
+ result = await process_semantic_text_chunker(record, dummy_text_chunker)
+ assert result["recordId"] == "1"
+ assert result["data"] is not None
+ chunks = result["data"]["chunks"]
+ assert isinstance(chunks, list)
+ assert len(chunks) == 1
+ assert chunks[0]["mark_up"] == "chunk1"
+ # When no page info is provided, page_number remains unchanged (None in our dummy).
+ assert chunks[0]["page_number"] is None
+
+
+@pytest.mark.asyncio
+async def test_process_semantic_text_chunker_success_with_page():
+ """Test a successful chunking when per-page starting sentences are provided and match a chunk."""
+ record = {
+ "recordId": "2",
+ "data": {
+ "content": "Some content to be chunked.",
+ "per_page_starting_sentences": [
+ {"starting_sentence": "chunk", "page_number": 5}
+ ],
+ },
+ }
+
+ dummy_chunk = DummyChunkHolder("This dummy chunk contains chunk in its text")
+ dummy_text_chunker = MagicMock()
+ dummy_text_chunker.chunk = AsyncMock(return_value=[dummy_chunk])
+
+ def dummy_assign_page(chunks, per_page_starting_sentences):
+ ps_objs = [
+ DummyPerPageStartingSentenceHolder(**ps.__dict__)
+ for ps in per_page_starting_sentences
+ ]
+ page_number = 1
+ for chunk in chunks:
+ for ps in ps_objs:
+ if ps.starting_sentence in chunk.mark_up:
+ page_number = ps.page_number
+ break
+ chunk.page_number = page_number
+ return chunks
+
+ dummy_text_chunker.assign_page_number_to_chunks = dummy_assign_page
+
+ result = await process_semantic_text_chunker(record, dummy_text_chunker)
+ assert result["recordId"] == "2"
+ chunks = result["data"]["chunks"]
+ assert isinstance(chunks, list)
+ assert len(chunks) == 1
+ assert chunks[0]["page_number"] == 5
+
+
+@pytest.mark.asyncio
+async def test_process_semantic_text_chunker_failure():
+ """Test that an exception during chunking is caught and returns an error record."""
+ record = {
+ "recordId": "3",
+ "data": {"content": "Content that will trigger an error."},
+ }
+
+ dummy_text_chunker = MagicMock()
+ dummy_text_chunker.chunk = AsyncMock(side_effect=Exception("Chunking error"))
+ dummy_text_chunker.assign_page_number_to_chunks = MagicMock()
+
+ result = await process_semantic_text_chunker(record, dummy_text_chunker)
+ assert result["recordId"] == "3"
+ assert result["data"] is None
+ assert "errors" in result
+ assert isinstance(result["errors"], list)
+ assert result["errors"][0]["message"].startswith("Failed to chunk data")
+
+
+@pytest.mark.asyncio
+async def test_process_semantic_text_chunker_multiple_chunks():
+ """
+ Test a record where chunk() returns multiple chunks and per-page starting sentences
+ assign different page numbers to different chunks.
+ """
+ record = {
+ "recordId": "4",
+ "data": {
+ "content": "Content that generates multiple chunks.",
+ "per_page_starting_sentences": [
+ {"starting_sentence": "first_page", "page_number": 3},
+ {"starting_sentence": "second_page", "page_number": 4},
+ ],
+ },
+ }
+
+ dummy_chunk1 = DummyChunkHolder("This chunk contains first_page indicator")
+ dummy_chunk2 = DummyChunkHolder("This chunk contains second_page indicator")
+ dummy_text_chunker = MagicMock()
+ dummy_text_chunker.chunk = AsyncMock(return_value=[dummy_chunk1, dummy_chunk2])
+
+ def dummy_assign_page(chunks, per_page_starting_sentences):
+ ps_objs = [
+ DummyPerPageStartingSentenceHolder(**ps.__dict__)
+ for ps in per_page_starting_sentences
+ ]
+ page_number = 1
+ for chunk in chunks:
+ for ps in ps_objs:
+ if ps.starting_sentence in chunk.mark_up:
+ page_number = ps.page_number
+ break
+ chunk.page_number = page_number
+ return chunks
+
+ dummy_text_chunker.assign_page_number_to_chunks = dummy_assign_page
+
+ result = await process_semantic_text_chunker(record, dummy_text_chunker)
+ assert result["recordId"] == "4"
+ chunks = result["data"]["chunks"]
+ assert isinstance(chunks, list)
+ assert len(chunks) == 2
+ assert chunks[0]["page_number"] == 3
+ assert chunks[1]["page_number"] == 4
+
+
+@pytest.mark.asyncio
+async def test_process_semantic_text_chunker_empty_page_sentences():
+ """
+ Test a record where 'per_page_starting_sentences' exists but is empty.
+ In this case, the default page (1) is assigned.
+ """
+ record = {
+ "recordId": "5",
+ "data": {
+ "content": "Some content to be chunked.",
+ "per_page_starting_sentences": [],
+ },
+ }
+
+ dummy_chunk = DummyChunkHolder("Chunk without any page indicator")
+ dummy_text_chunker = MagicMock()
+ dummy_text_chunker.chunk = AsyncMock(return_value=[dummy_chunk])
+
+ def dummy_assign_page(chunks, per_page_starting_sentences):
+ for chunk in chunks:
+ chunk.page_number = 1
+ return chunks
+
+ dummy_text_chunker.assign_page_number_to_chunks = dummy_assign_page
+
+ result = await process_semantic_text_chunker(record, dummy_text_chunker)
+ assert result["recordId"] == "5"
+ chunks = result["data"]["chunks"]
+ assert isinstance(chunks, list)
+ assert len(chunks) == 1
+ assert chunks[0]["page_number"] == 1
+
+
+@pytest.mark.asyncio
+async def test_process_semantic_text_chunker_missing_data():
+ """
+ Test that if the record is missing the 'data' key, the function returns an error.
+ """
+ record = {"recordId": "6"}
+ dummy_text_chunker = MagicMock()
+ dummy_text_chunker.chunk = AsyncMock(return_value=[DummyChunkHolder("chunk")])
+ dummy_text_chunker.assign_page_number_to_chunks = MagicMock()
+
+ result = await process_semantic_text_chunker(record, dummy_text_chunker)
+ assert result["recordId"] == "6"
+ assert result["data"] is None
+ assert "errors" in result
+
+
+@pytest.mark.asyncio
+async def test_process_semantic_text_chunker_empty_content():
+ """
+ Test that if the content is empty and chunk() raises a ValueError (e.g. because no chunks were generated),
+ the error is handled and an error record is returned.
+ """
+ record = {"recordId": "7", "data": {"content": ""}}
+ dummy_text_chunker = MagicMock()
+ dummy_text_chunker.chunk = AsyncMock(
+ side_effect=ValueError("No chunks were generated")
+ )
+ dummy_text_chunker.assign_page_number_to_chunks = MagicMock()
+
+ result = await process_semantic_text_chunker(record, dummy_text_chunker)
+ assert result["recordId"] == "7"
+ assert result["data"] is None
+ assert "errors" in result
+ assert isinstance(result["errors"], list)
+ assert result["errors"][0]["message"].startswith("Failed to chunk data")
+
+
+# --- Helper Classes for Chunk Splitting Tests ---
+
+
+# A simple dummy spaCy-like model for sentence segmentation.
+class DummySpan:
+ def __init__(self, text):
+ self.text = text
+
+
+class DummyDoc:
+ def __init__(self, text):
+ # Naively split on period.
+ # (Ensure test texts include periods as sentence delimiters.)
+ sentences = [s.strip() for s in text.split(".") if s.strip()]
+ self.sents = [DummySpan(s) for s in sentences]
+
+
+class DummyNLP:
+ def __call__(self, text):
+ return DummyDoc(text)
+
+
+# Fixture that returns a SemanticTextChunker instance with patched components.
+@pytest.fixture
+def chunker():
+ # Use relaxed thresholds so that even short sentences qualify.
+ stc = SemanticTextChunker(
+ num_surrounding_sentences=1,
+ similarity_threshold=0.8,
+ max_chunk_tokens=1000,
+ min_chunk_tokens=1,
+ )
+ # Override the spaCy model with our dummy.
+ stc._nlp_model = DummyNLP()
+ # Override token counting to simply count words.
+ stc.num_tokens_from_string = lambda s: len(s.split())
+ # For these tests, assume all sentences are very similar (so merge_similar_chunks doesn't force a split).
+ stc.sentence_similarity = lambda a, b: 1.0
+ return stc
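+
+# With these patches the fixture is deterministic: segmentation, token counts
+# and similarity no longer depend on spaCy or an embedding model. A quick
+# sanity check of the dummy pipeline (hypothetical usage):
+#
+#     doc = DummyNLP()("One. Two.")
+#     assert [s.text for s in doc.sents] == ["One", "Two"]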
+
+
+# --- Chunk Splitting Tests Using Real (Patched) Chunker ---
+
+
+@pytest.mark.asyncio
+async def test_chunk_complete_figure(chunker):
+ """
+ Test a text containing a complete
element.
+ Expect that the sentence with the complete figure is detected and grouped.
+ """
+ text = "Text before. Figure content. Text after."
+ chunks = await chunker.chunk(text)
+ # For our dummy segmentation, we expect two final chunks:
+ # one that combines "Text before" and the figure, and one for "Text after".
+ assert len(chunks) == 2
+ # Check that the first chunk contains a complete figure.
+ assert "
" in chunks[0].mark_up
+
+
+@pytest.mark.asyncio
+async def test_chunk_incomplete_figure(chunker):
+ """
+ Test a text with an incomplete figure element spanning multiple sentences.
+ The start and end of the figure should be grouped together.
+ """
+ text = (
+ "Text before. <figure>Start of figure. Figure continues</figure>. Text after."
+ )
+ chunks = await chunker.chunk(text)
+ # Expected grouping: one chunk combining the normal text and the grouped figure,
+ # and another chunk for text after.
+ assert len(chunks) == 2
+ # Check that the grouped chunk contains both the start and the end of the figure.
+ assert "<figure>" in chunks[0].mark_up
+ assert "</figure>" in chunks[0].mark_up
+
+
+@pytest.mark.asyncio
+async def test_chunk_markdown_heading(chunker):
+ """
+ Test that a markdown heading is padded with newlines.
+ """
+ text = "Introduction. # Heading. More text."
+ chunks = await chunker.chunk(text)
+ # The heading should have been transformed to include "\n\n" before and after.
+ # Because merge_chunks may merge sentences, check that the final text contains the padded heading.
+ combined = " ".join(chunk.mark_up for chunk in chunks)
+ assert "\n\n# Heading\n\n" in combined
+
+
+@pytest.mark.asyncio
+async def test_chunk_table(chunker):
+ """
+ Test that a complete table element is detected.
+ """
+ text = "Before table.
Table content
. After table."
+ chunks = await chunker.chunk(text)
+ # Expect at least one chunk containing a complete table.
+ table_chunks = [
+ c.mark_up for c in chunks if "<table>