@@ -69,18 +69,43 @@ def generate_visualization_schemas():
69
69
id = call_llm_plotly_schema (json .dumps (scenario_data ), id )
70
70
71
71
def generate_detailed_visualization_schemas():
    """Generate detailed schemas for a chart-type-balanced sample of base schemas.

    Reads every file in ``generated_schema``, groups the files by the ``type``
    of their first trace (``json_data['data'][0]['type']``), and draws an equal
    quota from each group. The selection is then topped up with random leftover
    files so that spare candidates exist when a generation attempt does not
    succeed. Candidates are passed to ``call_llm_detailed_plotly_schema`` until
    ``sample_size`` of them have been generated successfully or the candidates
    run out.

    Returns:
        None. Prints a hint and returns early when the schema directory is
        missing or empty.
    """
    print('Generating detailed schemas')
    scenario_dir = 'generated_schema'
    if not os.path.exists(scenario_dir):
        print("Populate the schemas first by calling generate_visualization_schemas()")
        # Nothing to sample from; listing the missing directory would raise.
        return

    # List the directory once so every step below works on the same snapshot.
    file_names = os.listdir(scenario_dir)
    if not file_names:
        # No files means no chart types, and the per-type quota below would
        # divide by zero.
        print(f"No schema files found in {scenario_dir}")
        return

    # First id handed to call_llm_detailed_plotly_schema.
    # NOTE(review): hard-coded; presumably continues after ids that are
    # already in use -- confirm before re-running.
    next_id = 303
    sample_size = min(25, len(file_names))

    # Group file names by the chart type of their first trace.
    files_by_chart_type = {}
    for file_name in file_names:
        json_data = read_json_file(os.path.join(scenario_dir, file_name))
        chart_type = json_data['data'][0]['type']
        files_by_chart_type.setdefault(chart_type, []).append(file_name)

    # Equal quota per chart type. This is 0 when there are more chart types
    # than sample_size, in which case the top-up below supplies every candidate.
    min_samples_per_chart_type = sample_size // len(files_by_chart_type)
    random_files = []
    for files in files_by_chart_type.values():
        random_files.extend(
            random.sample(files, min(min_samples_per_chart_type, len(files)))
        )

    # keeping extra buffer data in case any one fails to generate
    buffer_data = 50
    shortfall = sample_size + buffer_data - len(random_files)
    if shortfall > 0:
        already_picked = set(random_files)
        remaining_files = [name for name in file_names if name not in already_picked]
        # Clamp to what is actually left: random.sample raises ValueError when
        # asked for more items than the population holds.
        random_files.extend(
            random.sample(remaining_files, min(shortfall, len(remaining_files)))
        )

    count = 0
    for file_name in random_files:
        if count == sample_size:
            break
        json_data = read_json_file(os.path.join(scenario_dir, file_name))
        curr_id = json_data['id']
        file_name_prefix = file_name.split('.')[0]
        # The suffix is whatever follows the schema id in the file name.
        suffix = file_name_prefix.split(str(curr_id))[1]
        next_id, is_success = call_llm_detailed_plotly_schema(
            json.dumps(json_data), next_id, suffix
        )
        # Only successful generations count towards the target; skipped or
        # failed ones are replaced by the buffered candidates.
        if is_success:
            count += 1
84
109
85
110
def generate_locale_visualization_schemas ():
86
111
if not os .path .exists (scenario_dir ):
@@ -303,7 +328,7 @@ def call_llm_detailed_plotly_schema(scenario: str, id: int, suffix: str):
303
328
# call only if the file does not exist
304
329
if os .path .exists (f'generated_schema_detailed/data_{ id } { suffix } .json' ):
305
330
print (f"Skipping { id } _{ suffix } " )
306
- return id + 1
331
+ return id + 1 , False
307
332
308
333
# in case text_output is not a valid json, it will retry 3 times
309
334
retry_count = 0
@@ -318,7 +343,7 @@ def call_llm_detailed_plotly_schema(scenario: str, id: int, suffix: str):
318
343
319
344
if retry_count == 3 :
320
345
print ("Failed to generate schema" )
321
- return id
346
+ return id , False
322
347
323
348
output_dir = 'generated_schema_detailed'
324
349
os .makedirs (output_dir , exist_ok = True )
@@ -329,7 +354,7 @@ def call_llm_detailed_plotly_schema(scenario: str, id: int, suffix: str):
329
354
json .dump (data , file , indent = 4 )
330
355
id = id + 1
331
356
332
- return id
357
+ return id , True
333
358
334
359
def get_chart_type_from_image ():
335
360
directory_path = os .path .join ('..' , 'tests' , 'Plotly.spec.ts-snapshots' )
@@ -390,10 +415,10 @@ def get_chart_type_from_image():
390
415
# generate_visualization_schemas()
391
416
392
417
# Generate detailed schemas
393
- # generate_detailed_visualization_schemas()
418
+ generate_detailed_visualization_schemas ()
394
419
395
420
# Generate locale based schemas
396
421
# generate_locale_visualization_schemas()
397
422
398
423
# Generate chart types from screenshots taken by Playwright
399
- get_chart_type_from_image ()
424
+ # get_chart_type_from_image()
0 commit comments