Source code for claf.data.reader.wikisql


import json
import logging
from pathlib import Path
import uuid

from overrides import overrides
from tqdm import tqdm

from claf.data.dataset import WikiSQLDataset
from claf.data.dto import Helper
from claf.data.reader.base import DataReader
from claf.data import utils
from claf.decorator import register
from claf.metric.wikisql_lib.dbengine import DBEngine
from claf.metric.wikisql_lib.query import Query

logger = logging.getLogger(__name__)


@register("reader:wikisql")
class WikiSQLReader(DataReader):
    """
    WikiSQL DataReader (http://arxiv.org/abs/1709.00103)

    * Args:
        file_paths: .json file paths (train and dev)
        tokenizers: defined tokenizers config (char/word)
    """

    def __init__(self, file_paths, tokenizers, context_max_length=None, is_test=None):
        super(WikiSQLReader, self).__init__(file_paths, WikiSQLDataset)
        self.is_test = is_test

        self.text_columns = ["column", "question"]

        if "word" not in tokenizers:
            raise ValueError("WordTokenizer is required. Define an English WordTokenizer.")
        self.word_tokenizer = tokenizers["word"]
        self.dbengine = None

    @overrides
    def _read(self, file_path, data_type=None):
        file_path = self.data_handler.read(file_path, return_path=True)
        file_path = Path(file_path)

        data_dir = file_path.parent
        file_name = file_path.stem

        db_path = data_dir / f"{file_name}.db"
        table_path = data_dir / f"{file_name}.tables.jsonl"
        self.dbengine = DBEngine(db_path)

        helper = Helper(**{
            "file_path": file_path,
            "db_path": db_path,
        })

        features, labels = [], []

        sql_datas, table_data = self.load_data(file_path, table_path, data_type=data_type)
        for sql_data in tqdm(sql_datas, desc=data_type):
            question = sql_data["question"]
            table_id = sql_data["table_id"]
            column_headers = table_data[table_id]["header"]

            feature_row = {"column": column_headers, "question": question}

            data_uid = str(uuid.uuid1())

            conditions_value_position = self.get_coditions_value_position(
                sql_data["question"], [x[2] for x in sql_data["sql"]["conds"]]
            )

            sql_query = Query.from_dict(sql_data["sql"], ordered=True)
            execution_result = self.dbengine.execute_query(table_id, sql_query, lower=True)

            label_row = {
                "id": data_uid,
                "table_id": table_id,
                "tokenized_question": self.word_tokenizer.tokenize(question),
                "aggregator_idx": sql_data["sql"]["agg"],
                "select_column_idx": sql_data["sql"]["sel"],
                "conditions_num": len(sql_data["sql"]["conds"]),
                "conditions_column_idx": [x[0] for x in sql_data["sql"]["conds"]],
                "conditions_operator_idx": [x[1] for x in sql_data["sql"]["conds"]],
                "conditions_value_string": [str(x[2]) for x in sql_data["sql"]["conds"]],
                "conditions_value_position": conditions_value_position,
                "sql_query": sql_query,
                "execution_result": execution_result,
            }

            features.append(feature_row)
            labels.append(label_row)

            helper.set_example(data_uid, {
                "question": question,
                "sql_query": sql_query,
                "execution_result": execution_result,
            })

            if self.is_test and len(labels) == 10:
                break

        return utils.make_batch(features, labels), helper.to_dict()
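    # For reference, a minimal sketch of one parsed entry from the sql .jsonl
    # file as `_read` consumes it (values are illustrative, not taken from the
    # dataset; only the keys accessed above are shown):
    #
    #   {
    #       "question": "what is the capital of france",
    #       "table_id": "1-1000181-1",
    #       "sql": {
    #           "agg": 0,                      # aggregator index
    #           "sel": 1,                      # selected column index
    #           "conds": [[0, 0, "france"]],   # [column_idx, operator_idx, value] triples
    #       },
    #   }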
    @overrides
    def read_one_example(self, inputs):
        """ inputs keys: question, column, db_path, table_id """
        question_text = inputs["question"]

        helper = {"tokenized_question": self.word_tokenizer.tokenize(question_text)}
        return inputs, helper
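    # A hypothetical `inputs` dict for read_one_example (values are illustrative
    # only; "column", "db_path" and "table_id" are passed through unchanged):
    #
    #   inputs = {
    #       "question": "what is the capital of france",
    #       "column": ["country", "capital", "population"],
    #       "db_path": "data/wikisql/dev.db",
    #       "table_id": "1-1000181-1",
    #   }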
    def load_data(self, sql_path, table_path, data_type=None):
        sql_data = []
        table_data = {}

        logger.info(f"Loading data from {sql_path}")
        with open(sql_path) as inf:
            for line in tqdm(inf, desc=f"sql_{data_type}"):
                sql = json.loads(line.strip())
                sql_data.append(sql)

        logger.info(f"Loading data from {table_path}")
        with open(table_path) as inf:
            for line in tqdm(inf, desc=f"table_{data_type}"):
                tab = json.loads(line.strip())
                table_data[tab["id"]] = tab

        for sql in sql_data:
            assert sql["table_id"] in table_data

        return sql_data, table_data
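    # Likewise, a rough sketch of one line from the `*.tables.jsonl` file, keyed
    # by "id" in `table_data`. The reader itself only relies on "id" and
    # "header"; the remaining keys follow the standard WikiSQL release and are
    # shown for illustration only:
    #
    #   {
    #       "id": "1-1000181-1",
    #       "header": ["country", "capital", "population"],
    #       "types": ["text", "text", "real"],
    #       "rows": [["france", "paris", 67000000], ...],
    #   }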
    def get_coditions_value_position(self, question, values):
        tokenized_question = self.word_tokenizer.tokenize(question.lower())
        tokenized_values = [self.word_tokenizer.tokenize(str(value).lower()) for value in values]

        START_TOKEN, END_TOKEN = "<BEG>", "<END>"

        token_to_index = {START_TOKEN: 0}
        for token in tokenized_question:
            token_to_index[token] = len(token_to_index)
        token_to_index[END_TOKEN] = len(token_to_index)

        position_tokens = []
        for value in tokenized_values:
            position_token = [token_to_index[START_TOKEN]]
            for token in value:
                if token in token_to_index:
                    position_token.append(token_to_index[token])
                else:
                    for i in range(len(tokenized_question)):
                        q_token = tokenized_question[i]
                        if token in q_token:
                            position_token.append(token_to_index[q_token])
            position_token.append(token_to_index[END_TOKEN])

            assert len(position_token) != 2
            position_tokens.append(position_token)

        return position_tokens
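
# A worked example of get_coditions_value_position, assuming a simple
# whitespace-style WordTokenizer (the exact split depends on the configured
# tokenizer, so the indices below are illustrative):
#
#   question = "what is the capital of france"
#   values = ["France"]
#
#   token_to_index -> {"<BEG>": 0, "what": 1, "is": 2, "the": 3,
#                      "capital": 4, "of": 5, "france": 6, "<END>": 7}
#
#   get_coditions_value_position(question, values) -> [[0, 6, 7]]
#
# i.e. each condition value is encoded as the question-token indices it spans,
# wrapped between the <BEG> and <END> markers.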