# text_to_table.py
import os
import re
import sqlite3

import pandas as pd

from TabularSemanticParsing.src.parse_args import args
from TabularSemanticParsing.src.data_processor.schema_graph import SchemaGraph
from TabularSemanticParsing.src.demos.demos import Text2SQLWrapper
from TabularSemanticParsing.src.trans_checker.args import args as cs_args
import TabularSemanticParsing.src.utils.utils as utils
from ca_es_to_en import translate_es, translate_ca

# Set model ID from the model name given in the parsed command-line args.
args.model_id = utils.model_index[args.model]
# Explicit check rather than `assert`, which is stripped under `python -O`.
if args.model_id is None:
    raise ValueError('No model id registered for model {!r}'.format(args.model))

# Module state written by setup() and read by getTable().
t2sql = None    # Text2SQLWrapper built for the loaded table
schema = None   # SchemaGraph describing the loaded table
db_path = None  # path of the SQLite file holding the loaded table

# Top-level clauses that may follow a WHERE clause (or end one SELECT of a
# compound query); extra filters must be inserted before the first of them.
_AFTER_WHERE_RE = re.compile(
    r'\b(?:GROUP\s+BY|HAVING|ORDER\s+BY|LIMIT|UNION|INTERSECT|EXCEPT)\b')
_WHERE_RE = re.compile(r'\bWHERE\b')


def setup(csv_name, delimiter=','):
    """Load ``<args.csv_dir>/<csv_name>.csv`` and build the Text2SQL model.

    The CSV is written into a freshly created SQLite database
    ``<args.csv_dir>/<csv_name>.sqlite`` as a table named ``csv_name``; any
    existing database file with that name is deleted first. Also builds the
    SchemaGraph for the table and the Text2SQLWrapper used by getTable().

    Args:
        csv_name: base name (without extension) of the CSV file / table.
        delimiter: field separator passed to ``pandas.read_csv`` and to
            ``SchemaGraph.load_data_from_csv_file``.
    """
    global t2sql, schema, db_path

    csv_dir = args.csv_dir
    db_path = os.path.join(csv_dir, '{}.sqlite'.format(csv_name))
    print('* db_path = ' + db_path + '\n')
    # Rebuild the database from scratch so it always mirrors the CSV.
    if os.path.exists(db_path):
        os.remove(db_path)

    schema = SchemaGraph(csv_name, db_path=db_path)
    csv_path = os.path.join(csv_dir, '{}.csv'.format(csv_name))

    frame = pd.read_csv(csv_path, sep=delimiter)
    print('* rows: ' + str(frame.shape[0]) + ', columns: ' + str(frame.shape[1]) + '\n')
    conn = sqlite3.connect(db_path)
    try:
        frame.to_sql(csv_name, conn, if_exists='append', index=False)
    finally:
        # Release the connection even if writing the table fails.
        conn.close()

    # NOTE(review): an optional third argument, a '<csv_name>.types' path in
    # csv_dir, was left commented out here; presumably column type hints --
    # confirm against SchemaGraph.load_data_from_csv_file before enabling.
    schema.load_data_from_csv_file(csv_path, delimiter)
    schema.pretty_print()
    t2sql = Text2SQLWrapper(args, cs_args, schema)


def addFieldsToSql(fields, sql_query):
    """Replace the ``*`` of a leading ``SELECT * FROM`` with explicit columns.

    Args:
        fields: column names/expressions to select. They are inserted
            verbatim (not quoted or escaped), so they must be valid SQL.
        sql_query: SQL text produced by the Text2SQL model.

    Returns:
        The rewritten query, or ``sql_query`` unchanged when ``fields`` is
        empty or the query does not start with ``SELECT * FROM``.
    """
    # Anchor at the start: a `SELECT * FROM` inside a sub-query must not be
    # rewritten, and no text in front of it may be dropped.
    select_all_from = re.match(r'\s*SELECT \* FROM', sql_query)
    if not fields or select_all_from is None:
        return sql_query
    return 'SELECT {} FROM{}'.format(
        ', '.join(map(str, fields)), sql_query[select_all_from.end():])


def _mask_nested(sql):
    """Return ``sql`` with quoted text and parenthesised groups blanked out.

    The result has the same length as ``sql``, so match offsets found in it
    can slice the original. Keywords inside string literals, quoted
    identifiers or sub-queries are hidden from the top-level clause search.
    A doubled quote ('it''s') closes and reopens the literal, which masks
    it correctly.
    """
    masked = []
    depth = 0
    quote = None
    for ch in sql:
        if quote is not None:
            if ch == quote:
                quote = None
            masked.append(' ')
        elif ch in '\'"`':
            quote = ch
            masked.append(' ')
        elif ch == '(':
            depth += 1
            masked.append(' ')
        elif ch == ')':
            depth = max(depth - 1, 0)
            masked.append(' ')
        elif depth:
            masked.append(' ')
        else:
            masked.append(ch)
    return ''.join(masked)


def addFiltersToSql(filters, sql_query):
    """AND extra conditions into the top-level WHERE clause of ``sql_query``.

    The existing WHERE condition is wrapped in parentheses so that an ``OR``
    in it cannot swallow the new filters. When there is no WHERE clause,
    ``WHERE 1=1`` is added. The filters are inserted before any trailing
    GROUP BY / HAVING / ORDER BY / LIMIT clause (or compound operator), so
    the clause order stays valid.

    Args:
        filters: SQL boolean expressions, each appended as ``AND (<expr>)``.
            They are inserted verbatim (not escaped), so they must be
            trusted, valid SQL.
        sql_query: a SELECT statement with upper-case keywords.

    Returns:
        The rewritten query, or ``sql_query`` unchanged when ``filters`` is
        empty. For compound queries (UNION/INTERSECT/EXCEPT), only the first
        SELECT is filtered.
    """
    if not filters:
        return sql_query

    masked = _mask_nested(sql_query)
    tail_match = _AFTER_WHERE_RE.search(masked)
    split_at = tail_match.start() if tail_match else len(sql_query)
    head = sql_query[:split_at].rstrip()
    tail = sql_query[split_at:].strip()

    # Only a WHERE at top level and before the tail clauses belongs to the
    # outer SELECT; offsets in `head` equal offsets in `sql_query`.
    where_match = _WHERE_RE.search(masked, 0, split_at)
    if where_match:
        condition = head[where_match.end():].strip()
        head = head[:where_match.start()].rstrip() + ' WHERE ({})'.format(condition)
    else:
        head += ' WHERE 1=1'

    head += ''.join(' AND ({})'.format(f) for f in filters)
    return head + ' ' + tail if tail else head


def _run_query(sql_query):
    """Execute ``sql_query`` against ``db_path`` and return ``(header, rows)``.

    Errors are printed rather than raised, so a bad query yields an empty
    result instead of failing the request. ``header`` may already be filled
    if only fetching the rows fails.
    """
    header = []
    table = []
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.execute(sql_query)
        # description is None for statements that return no columns.
        if cursor.description is not None:
            header = [column[0] for column in cursor.description]
            table = cursor.fetchall()
    except Exception as e:  # deliberate best-effort boundary: report, don't crash
        print(e)
    finally:
        # Always release the connection; one was opened per request.
        conn.close()
    return header, table


def getTable(body):
    """Answer a natural-language question with rows from the loaded table.

    Args:
        body: request payload (presumably decoded JSON -- confirm with the
            caller). All keys are required; their values may be empty:
            ``input``       question text;
            ``language``    ``'ca'`` or ``'es'`` is translated to English
                            first, any other value is used as-is;
            ``fields``      columns replacing ``SELECT *`` (addFieldsToSql);
            ``filters``     SQL conditions AND-ed into the WHERE clause
                            (addFiltersToSql);
            ``ignoreCase``  truthy to append ``COLLATE NOCASE`` when the
                            query has a WHERE clause.

    Returns:
        ``(response, 200)``. ``response`` holds the original and English
        input, the SQL that was run, and the result ``header``/``table``.
        Query errors are printed and produce an empty header/table.

    Raises:
        KeyError: a required key is missing from ``body``.
        RuntimeError: setup() has not been called yet.

    NOTE(review): ``fields`` and ``filters`` are spliced into the SQL text
    unescaped. Python's sqlite3 ``execute`` runs only one statement, but
    these values should still come from trusted callers.
    """
    input_text_original = body['input']
    language = body['language']
    fields = body['fields']
    filters = body['filters']
    ignoreCase = body['ignoreCase']

    if t2sql is None:
        raise RuntimeError('setup() must be called before getTable()')

    input_text = input_text_original
    if language == 'ca':
        input_text = translate_ca(input_text_original)
    elif language == 'es':
        input_text = translate_es(input_text_original)

    output = t2sql.process(input_text, schema.name)
    sql_query = output['sql_query']

    header = []
    table = []
    if sql_query is None:
        # The model produced no query; report an empty result.
        sql_query = ''
    else:
        if fields:
            sql_query = addFieldsToSql(fields, sql_query)
        if filters:
            sql_query = addFiltersToSql(filters, sql_query)
        if ignoreCase and re.search(r'SELECT (.)* FROM (.)* WHERE (.)*', sql_query):
            # NOTE(review): COLLATE is a postfix operator on the immediately
            # preceding term, so only the query's final term is affected. It
            # has no effect when that term is a parenthesised filter or an
            # ORDER BY / LIMIT expression.
            sql_query += ' COLLATE NOCASE'
        header, table = _run_query(sql_query)

    print('* input = ' + input_text_original)
    print('* input in english = ' + input_text)
    print('* sql = ' + sql_query)
    print('* header = ' + str(header))
    print('* table = ' + str(table))
    print()

    response = {
        'input': input_text_original,
        'input_en': input_text,
        'sql': sql_query,
        'header': header,
        'table': table
    }
    return response, 200