import numpy as np
import torch


def encode_column(column_embed, column_name_mask, rnn_module):
    """Encode each column name with an RNN and keep the last valid hidden state.

    column_embed: [B, C_L, N_L, embed_D] token embeddings of the column names
    column_name_mask: [B, C_L, N_L] with 1 for real name tokens, 0 for padding
    rnn_module: a batch-first RNN (e.g. nn.GRU / nn.LSTM) returning (output, state)
    returns: [B, C_L, encoded_D], one vector per column
    """
    B, C_L, N_L, embed_D = list(column_embed.size())
    column_lengths = get_column_lengths(column_embed, column_name_mask)  # [B * C_L]
    # Index of the last real token per column: length - 1, clamped to 0 for
    # all-padding columns (gt(0) is 0 when the length is 0).
    column_last_index = column_lengths - column_lengths.gt(0).long()
    # Flatten batch and column dims so each column name is one RNN sequence.
    column_embed = column_embed.view(-1, N_L, embed_D)  # [B * C_L, N_L, embed_D]
    encoded_column, _ = rnn_module(column_embed)  # [B * C_L, N_L, encoded_D]
    encoded_D = encoded_column.size(-1)
    # Select the hidden state at the last real token of every column name;
    # the length-1 slice keeps the index tensor 1-D, as index_select expects.
    encoded_output_column = torch.cat(
        [
            torch.index_select(encoded_column[i], 0, column_last_index[i : i + 1])
            for i in range(column_last_index.size(0))
        ],
        dim=0,
    )  # [B * C_L, encoded_D]
    encoded_output_column = encoded_output_column.view([B, C_L, encoded_D])
    return encoded_output_column
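
# Illustrative usage sketch, not part of the original module: the toy shapes
# (B=2 tables, C_L=3 columns, N_L=4 name tokens, embed_D=8) and the GRU are
# assumptions chosen only to show how encode_column is called.
def _demo_encode_column():
    column_embed = torch.randn(2, 3, 4, 8)  # [B, C_L, N_L, embed_D]
    # Mark the first two tokens of every column name as real, the rest padding.
    column_name_mask = torch.zeros(2, 3, 4)
    column_name_mask[:, :, :2] = 1
    rnn = torch.nn.GRU(input_size=8, hidden_size=16, batch_first=True)
    encoded = encode_column(column_embed, column_name_mask, rnn)
    assert encoded.size() == (2, 3, 16)  # one 16-dim vector per column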


def get_column_lengths(column_embed, column_name_mask):
    """Count the real name tokens per column; returns a [B * C_L] long tensor."""
    _, _, N_L, _ = list(column_embed.size())
    return torch.sum(column_name_mask.view(-1, N_L), dim=-1).long()
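
# Quick check of the length computation above, with assumed toy shapes: a
# [1, 2, 3] mask whose first column name has 3 real tokens and whose second
# has 1 yields lengths [3, 1], flattened over B * C_L.
def _demo_get_column_lengths():
    column_embed = torch.randn(1, 2, 3, 5)  # only its shape is used
    column_name_mask = torch.tensor([[[1.0, 1.0, 1.0], [1.0, 0.0, 0.0]]])
    lengths = get_column_lengths(column_embed, column_name_mask)
    assert lengths.tolist() == [3, 1]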


def filter_used_column(encoded_columns, col_idx, padding_count=4):
    """Gather the encodings of the used columns, zero-padded to a fixed size.

    encoded_columns: [B, C_L, D] column encodings
    col_idx: per-example indices of the used columns; each entry is expected
        to hold at most padding_count indices
    returns: [B, padding_count, D]
    """
    B, C_L, D = list(encoded_columns.size())
    # Build the padding vector on the same device and dtype as the input, so
    # the stacks below never mix CPU and CUDA tensors.
    zero_padding = encoded_columns.new_zeros(D)
    encoded_used_columns = []
    for i in range(B):
        encoded_used_column = torch.stack(
            [encoded_columns[i][j] for j in col_idx[i]]
            + [zero_padding] * (padding_count - len(col_idx[i]))
        )
        encoded_used_columns.append(encoded_used_column)
    return torch.stack(encoded_used_columns)
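
# Illustrative usage sketch with assumed shapes and indices, not from the
# original module: keep columns 0 and 2 of the first example and column 1 of
# the second, zero-padding each selection up to padding_count vectors.
def _demo_filter_used_column():
    encoded_columns = torch.randn(2, 5, 16)  # [B, C_L, D]
    col_idx = [[0, 2], [1]]  # used column indices per example
    used = filter_used_column(encoded_columns, col_idx, padding_count=4)
    assert used.size() == (2, 4, 16)
    assert used[0, 2].abs().sum().item() == 0  # third slot of example 0 is padding


if __name__ == "__main__":
    _demo_encode_column()
    _demo_get_column_lengths()
    _demo_filter_used_column()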