Skip to content

Commit 3c893a3

Browse files
committed
expanded chunking to more different blocks
1 parent 71a5a3b commit 3c893a3

5 files changed

+167
-77
lines changed

kaizen/retriever/code_chunker.py

+84-16
Original file line numberDiff line numberDiff line change
@@ -7,43 +7,100 @@
77
def chunk_code(code: str, language: str) -> ParsedBody:
88
parser = ParserFactory.get_parser(language)
99
tree = parser.parse(code.encode("utf8"))
10-
10+
code_bytes = code.encode("utf8")
1111
body: ParsedBody = {
12+
"imports": [],
13+
"global_variables": [],
14+
"type_definitions": [],
1215
"functions": {},
16+
"async_functions": {},
1317
"classes": {},
1418
"hooks": {},
1519
"components": {},
20+
"jsx_elements": [],
1621
"other_blocks": [],
1722
}
18-
# code_bytes = code.encode("utf8")
1923

2024
def process_node(node):
21-
result = parse_code(code, language)
25+
result = parse_code(node, code_bytes)
2226
if result:
23-
# Assuming parse_code is modified to return line numbers
2427
start_line = result.get("start_line", 0)
2528
end_line = result.get("end_line", 0)
2629

27-
if result["type"] == "function":
30+
if result["type"] == "import_statement":
31+
body["imports"].append(
32+
{
33+
"code": result["code"],
34+
"start_line": start_line,
35+
"end_line": end_line,
36+
}
37+
)
38+
elif (
39+
result["type"] == "variable_declaration"
40+
and node.parent.type == "program"
41+
):
42+
body["global_variables"].append(
43+
{
44+
"code": result["code"],
45+
"start_line": start_line,
46+
"end_line": end_line,
47+
}
48+
)
49+
elif result["type"] in ["type_alias", "interface_declaration"]:
50+
body["type_definitions"].append(
51+
{
52+
"name": result["name"],
53+
"code": result["code"],
54+
"start_line": start_line,
55+
"end_line": end_line,
56+
}
57+
)
58+
elif result["type"] == "function":
2859
if is_react_hook(result["name"]):
2960
body["hooks"][result["name"]] = {
3061
"code": result["code"],
3162
"start_line": start_line,
3263
"end_line": end_line,
3364
}
3465
elif is_react_component(result["code"]):
35-
body["components"][result["name"]] = result["code"]
66+
body["components"][result["name"]] = {
67+
"code": result["code"],
68+
"start_line": start_line,
69+
"end_line": end_line,
70+
}
71+
elif "async" in result["code"].split()[0]:
72+
body["async_functions"][result["name"]] = {
73+
"code": result["code"],
74+
"start_line": start_line,
75+
"end_line": end_line,
76+
}
3677
else:
37-
body["functions"][result["name"]] = result["code"]
78+
body["functions"][result["name"]] = {
79+
"code": result["code"],
80+
"start_line": start_line,
81+
"end_line": end_line,
82+
}
3883
elif result["type"] == "class":
3984
if is_react_component(result["code"]):
40-
body["components"][result["name"]] = result["code"]
85+
body["components"][result["name"]] = {
86+
"code": result["code"],
87+
"start_line": start_line,
88+
"end_line": end_line,
89+
}
4190
else:
42-
body["classes"][result["name"]] = result["code"]
43-
elif result["type"] == "component":
44-
body["components"][result["name"]] = result["code"]
45-
elif result["type"] == "impl":
46-
body["classes"][result["name"]] = result["code"]
91+
body["classes"][result["name"]] = {
92+
"code": result["code"],
93+
"start_line": start_line,
94+
"end_line": end_line,
95+
}
96+
elif result["type"] == "jsx_element":
97+
body["jsx_elements"].append(
98+
{
99+
"code": result["code"],
100+
"start_line": start_line,
101+
"end_line": end_line,
102+
}
103+
)
47104
else:
48105
for child in node.children:
49106
process_node(child)
@@ -55,8 +112,14 @@ def process_node(node):
55112
for section in body.values():
56113
if isinstance(section, dict):
57114
for code_block in section.values():
58-
start = code.index(code_block)
59-
collected_ranges.append((start, start + len(code_block)))
115+
collected_ranges.append(
116+
(code_block["start_line"], code_block["end_line"])
117+
)
118+
elif isinstance(section, list):
119+
for code_block in section:
120+
collected_ranges.append(
121+
(code_block["start_line"], code_block["end_line"])
122+
)
60123

61124
collected_ranges.sort()
62125
last_end = 0
@@ -76,5 +139,10 @@ def is_react_hook(name: str) -> bool:
76139

77140
def is_react_component(code: str) -> bool:
78141
return (
79-
"React" in code or "jsx" in code.lower() or "tsx" in code.lower() or "<" in code
142+
"React" in code
143+
or "jsx" in code.lower()
144+
or "tsx" in code.lower()
145+
or "<" in code
146+
or "props" in code
147+
or "render" in code
80148
)

kaizen/retriever/llama_index_retriever.py

+75-54
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from llama_index.embeddings.litellm import LiteLLMEmbedding
1313
from sqlalchemy import create_engine, text
1414
from kaizen.retriever.qdrant_vector_store import QdrantVectorStore
15-
15+
import json
1616

1717
# Set up logging
1818
logging.basicConfig(
@@ -43,10 +43,18 @@ def __init__(self, repo_id=1):
4343
)
4444
logger.info("RepositoryAnalyzer initialized successfully")
4545

46-
def setup_repository(self, repo_path: str, node_query: str = None):
46+
def setup_repository(
47+
self,
48+
repo_path: str,
49+
node_query: str = None,
50+
file_query: str = None,
51+
function_query: str = None,
52+
):
4753
self.total_usage = self.llm_provider.DEFAULT_USAGE
4854
self.total_files_processed = 0
4955
self.node_query = node_query
56+
self.file_query = file_query
57+
self.function_query = function_query
5058
self.embedding_usage = {"prompt_tokens": 10, "total_tokens": 10}
5159
logger.info(f"Starting repository setup for: {repo_path}")
5260
self.parse_repository(repo_path)
@@ -130,7 +138,7 @@ def process_code_block(
130138
return # Skip this code block
131139

132140
language = self.get_language_from_extension(file_path)
133-
abstraction, usage = self.generate_abstraction(code, language)
141+
abstraction, usage = self.generate_abstraction(code, language, section)
134142
self.total_usage = self.llm_provider.update_usage(
135143
total_usage=self.total_usage, current_usage=usage
136144
)
@@ -185,70 +193,85 @@ def store_abstraction_and_embedding(self, function_id: int, abstraction: str):
185193
logger.debug(f"Abstraction and embedding stored for function_id: {function_id}")
186194

187195
def generate_abstraction(
188-
self, code_block: str, language: str, max_tokens: int = 300
196+
self, code_block: str, language: str, section: str, max_tokens: int = 300
189197
) -> str:
190198
prompt = f"""Analyze the following {language} code block and generate a structured abstraction.
191-
Your response should be in YAML format and include the following sections:
199+
Your response should be in JSON format and include the following sections:
200+
201+
{{
202+
"summary": "A concise one-sentence summary of the function's primary purpose.",
192203
193-
summary: A concise one-sentence summary of the function's primary purpose.
204+
"functionality": "A detailed explanation of what the function does, including its main steps and logic. Use multiple lines if needed for clarity.",
194205
195-
functionality: |
196-
A detailed explanation of what the function does, including its main steps and logic.
197-
Use multiple lines if needed for clarity.
206+
"inputs": [
207+
{{
208+
"name": "The parameter name",
209+
"type": "The parameter type",
210+
"description": "A brief description of the parameter's purpose",
211+
"default_value": "The default value, if any (or null if not applicable)"
212+
}}
213+
],
198214
199-
inputs:
200-
- name: The parameter name
201-
type: The parameter type
202-
description: A brief description of the parameter's purpose
203-
default_value: The default value, if any (or null if not applicable)
215+
"output": {{
216+
"type": "The return type of the function",
217+
"description": "A description of what is returned and under what conditions. Use multiple lines if needed."
218+
}},
204219
205-
output:
206-
type: The return type of the function
207-
description: |
208-
A description of what is returned and under what conditions.
209-
Use multiple lines if needed.
220+
"dependencies": [
221+
{{
222+
"name": "Name of the external library or module",
223+
"purpose": "Brief explanation of its use in this function"
224+
}}
225+
],
210226
211-
dependencies:
212-
- name: Name of the external library or module
213-
purpose: Brief explanation of its use in this function
227+
"algorithms": [
228+
{{
229+
"name": "Name of the algorithm or data structure",
230+
"description": "Brief explanation of its use and importance"
231+
}}
232+
],
214233
215-
algorithms:
216-
- name: Name of the algorithm or data structure
217-
description: Brief explanation of its use and importance
234+
"edge_cases": [
235+
"A list of potential edge cases or special conditions the function handles or should handle"
236+
],
218237
219-
edge_cases:
220-
- A list of potential edge cases or special conditions the function handles or should handle
238+
"error_handling": "A description of how errors are handled or propagated. Include specific error types if applicable.",
221239
222-
error_handling: |
223-
A description of how errors are handled or propagated.
224-
Include specific error types if applicable.
240+
"usage_context": "A brief explanation of how this function might be used by parent functions or in a larger system. Include typical scenarios and any important considerations for its use.",
225241
226-
usage_context: |
227-
A brief explanation of how this function might be used by parent functions or in a larger system.
228-
Include typical scenarios and any important considerations for its use.
242+
"complexity": {{
243+
"time": "Estimated time complexity (e.g., O(n))",
244+
"space": "Estimated space complexity (e.g., O(1))",
245+
"explanation": "Brief explanation of the complexity analysis"
246+
}},
229247
230-
complexity:
231-
time: Estimated time complexity (e.g., O(n))
232-
space: Estimated space complexity (e.g., O(1))
248+
"tags": ["List", "of", "relevant", "tags"],
233249
234-
code_snippet: |
235-
```{language}
236-
{code_block}
237-
```
250+
"testing_considerations": "Suggestions for unit tests or test cases to cover key functionality and edge cases",
238251
239-
Provide your analysis in this clear, structured YAML format. If any section is not applicable, use an empty list [] or null value as appropriate. Ensure that multi-line descriptions are properly indented under their respective keys.
252+
"version_compatibility": "Information about language versions or dependency versions this code is compatible with",
253+
254+
"performance_considerations": "Any notes on performance optimizations or potential bottlenecks",
255+
256+
"security_considerations": "Any security-related notes or best practices relevant to this code",
257+
258+
"maintainability_score": "A subjective score from 1-10 on how easy the code is to maintain, with a brief explanation"
259+
}}
260+
261+
Provide your analysis in this clear, structured JSON format. If any section is not applicable, use an empty list [] or null value as appropriate. Ensure that multi-line descriptions are properly formatted as strings.
240262
241263
Code to analyze:
242-
```{language}
243-
{code_block}
244-
```
264+
Language: {language}
265+
Block Type: {section}
266+
Code Block:
267+
```{code_block}```
245268
"""
246269

247270
estimated_prompt_tokens = len(tokenizer.encode(prompt))
248271
adjusted_max_tokens = min(max(150, estimated_prompt_tokens), 1000)
249272

250273
try:
251-
abstraction, usage = self.llm_provider.chat_completion(
274+
abstraction, usage = self.llm_provider.chat_completion_with_json(
252275
prompt="",
253276
messages=[
254277
{
@@ -259,7 +282,7 @@ def generate_abstraction(
259282
],
260283
custom_model={"max_tokens": adjusted_max_tokens, "model": "small"},
261284
)
262-
return abstraction, usage
285+
return json.dumps(abstraction), usage
263286

264287
except Exception as e:
265288
raise e
@@ -272,21 +295,19 @@ def store_code_in_db(
272295
section: str,
273296
name: str,
274297
start_line: int,
275-
file_query: str = None,
276-
function_query: str = None,
277298
) -> int:
278299
logger.debug(f"Storing code in DB: {file_path} - {section} - {name}")
279300
with self.engine.begin() as connection:
280301
# Insert into files table (assuming this part is already correct)
281-
if not file_query:
282-
file_query = """
302+
if not self.file_query:
303+
self.file_query = """
283304
INSERT INTO files (repo_id, file_path, file_name, file_ext, programming_language)
284305
VALUES (:repo_id, :file_path, :file_name, :file_ext, :programming_language)
285306
ON CONFLICT (repo_id, file_path) DO UPDATE SET file_path = EXCLUDED.file_path
286307
RETURNING file_id
287308
"""
288309
file_id = connection.execute(
289-
text(file_query),
310+
text(self.file_query),
290311
{
291312
"repo_id": self.repo_id,
292313
"file_path": file_path,
@@ -297,15 +318,15 @@ def store_code_in_db(
297318
).scalar_one()
298319

299320
# Insert into function_abstractions table
300-
if not function_query:
301-
function_query = """
321+
if not self.function_query:
322+
self.function_query = """
302323
INSERT INTO function_abstractions
303324
(file_id, function_name, function_signature, abstract_functionality, start_line, end_line)
304325
VALUES (:file_id, :function_name, :function_signature, :abstract_functionality, :start_line, :end_line)
305326
RETURNING function_id
306327
"""
307328
function_id = connection.execute(
308-
text(function_query),
329+
text(self.function_query),
309330
{
310331
"file_id": file_id,
311332
"function_name": name,

kaizen/retriever/qdrant_vector_store.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ class QdrantVectorStore:
1111
def __init__(self, collection_name, vector_size, max_retries=3, retry_delay=2):
1212
self.HOST = os.getenv("QDRANT_HOST", "localhost")
1313
self.PORT = os.getenv("QDRANT_PORT", "6333")
14+
self.QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
1415
self.collection_name = collection_name
1516
self.max_retries = max_retries
1617
self.retry_delay = retry_delay
@@ -49,7 +50,9 @@ def _create_collection(self, vector_size):
4950

5051
def add(self, nodes):
5152
points = [
52-
PointStruct(id=node.id_, vector=node.embedding, payload=node.metadata)
53+
PointStruct(
54+
id=node["id"], vector=node["embedding"], payload=node["metadata"]
55+
)
5356
for node in nodes
5457
]
5558
self.client.upsert(collection_name=self.collection_name, points=points)

kaizen/retriever/tree_sitter_utils.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -101,13 +101,11 @@ def traverse_tree(node, code_bytes: bytes) -> Dict[str, Any]:
101101
return None
102102

103103

104-
def parse_code(code: str, language: str) -> Dict[str, Any]:
104+
def parse_code(node: Any, code_bytes: bytes) -> Dict[str, Any]:
105105
try:
106-
parser = ParserFactory.get_parser(language)
107-
tree = parser.parse(bytes(code, "utf8"))
108-
return traverse_tree(tree.root_node, code.encode("utf8"))
106+
return traverse_tree(node, code_bytes)
109107
except Exception as e:
110-
logger.error(f"Failed to parse {language} code: {str(e)}")
108+
logger.error(f"Failed to parse code: {str(e)}")
111109
raise
112110

113111

0 commit comments

Comments
 (0)