Source code for tmnt.data_loading

# coding: utf-8
# Copyright (c) 2019-2021. The MITRE Corporation.
"""
File/module contains routines for loading in text documents to sparse matrix representations
for efficient neural variational model training.
"""

import io
import itertools
import os
import logging
import scipy
import numpy as np
import string
import re
import json
from collections import OrderedDict, Counter
from sklearn.datasets import load_svmlight_file
from sklearn.utils import shuffle as sk_shuffle
from tmnt.preprocess.vectorizer import TMNTVectorizer
import random

from scipy import sparse as sp
from typing import List, Tuple, Dict, Optional, Union, NoReturn

import torch
from torch.utils.data import DataLoader, Sampler, WeightedRandomSampler, RandomSampler
from torchtext.vocab import vocab as build_vocab
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from transformers import DistilBertTokenizer, DistilBertModel, AutoTokenizer, AutoModel, DistilBertTokenizer, BertModel, DistilBertModel, OpenAIGPTModel
from sklearn.model_selection import StratifiedKFold

#### Huggingface LLM-specific dataloading ####

# Registry of supported Hugging Face checkpoints.  Each value is a pair
# ``(tokenizer_loader, model_loader)``; the accessor functions in this module
# call the selected loader with the checkpoint name as its only argument.
llm_catalog = {
    'distilbert-base-uncased': (DistilBertTokenizer.from_pretrained, DistilBertModel.from_pretrained),
    'bert-base-uncased': (AutoTokenizer.from_pretrained, BertModel.from_pretrained),
    'openai-gpt': (AutoTokenizer.from_pretrained, OpenAIGPTModel.from_pretrained),
}

# The remaining checkpoints are all loaded through the generic Auto* classes.
llm_catalog.update(
    (checkpoint, (AutoTokenizer.from_pretrained, AutoModel.from_pretrained))
    for checkpoint in (
        'sentence-transformers/all-mpnet-base-v2',
        'allenai/scibert_scivocab_uncased',
        'johngiorgi/declutr-sci-base',
        'BAAI/bge-base-en-v1.5',
        ## add more model options here if desired
    )
)

def get_llm(model_name):
    """Instantiate the tokenizer and the model registered for ``model_name``.

    Args:
        model_name: Key into ``llm_catalog``; it is also passed as the only
            argument to both registered loader callables.

    Returns:
        Tuple ``(tokenizer, model)`` holding the two loaders' results.

    Raises:
        KeyError: If ``model_name`` is not a key of ``llm_catalog``.
    """
    load_tokenizer, load_model = llm_catalog[model_name]
    tokenizer = load_tokenizer(model_name)
    model = load_model(model_name)
    return tokenizer, model
def get_llm_tokenizer(model_name):
    """Instantiate only the tokenizer registered for ``model_name``.

    Args:
        model_name: Key into ``llm_catalog``; also passed as the only
            argument to the registered tokenizer loader.

    Returns:
        Whatever the registered tokenizer loader returns.

    Raises:
        KeyError: If ``model_name`` is not a key of ``llm_catalog``.
    """
    load_tokenizer, _unused_model_loader = llm_catalog[model_name]
    tokenizer = load_tokenizer(model_name)
    return tokenizer
def get_llm_model(model_name):
    """Instantiate only the model registered for ``model_name``.

    Args:
        model_name: Key into ``llm_catalog``; also passed as the only
            argument to the registered model loader.

    Returns:
        Whatever the registered model loader returns.

    Raises:
        KeyError: If ``model_name`` is not a key of ``llm_catalog``.
    """
    _unused_tokenizer_loader, load_model = llm_catalog[model_name]
    model = load_model(model_name)
    return model
def get_unwrapped_llm_dataloader(data, bow_vectorizer, llm_name, label_map, batch_size, max_len, shuffle=False, device='cpu'):
    """Build a ``DataLoader`` that tokenizes and vectorizes text on the fly.

    Args:
        data: Dataset handed to ``DataLoader``; every item is unpacked as a
            ``(label, text)`` pair by the collate function.
        bow_vectorizer: Object whose ``transform([text])`` returns a 2-tuple;
            the first element must offer ``tocoo()`` (a scipy sparse matrix,
            presumably one row per document - confirm against the vectorizer).
        llm_name: Key of ``llm_catalog`` selecting the tokenizer.
        label_map: Mapping from raw label to integer id; labels missing from
            the mapping are encoded as ``0``.
        batch_size: Number of items per batch.
        max_len: Length every text is padded/truncated to by the tokenizer.
        shuffle: Forwarded to ``DataLoader``.
        device: Device that every tensor of a batch is moved to.

    Returns:
        A ``DataLoader`` whose batches are 4-tuples
        ``(labels, input_ids, attention_mask, bag_of_words)``.
    """
    tokenizer = get_llm_tokenizer(llm_name)

    def collate_batch(batch):
        # Accumulate the per-example pieces, then stack each group once.
        label_ids, token_ids, attention_masks, bow_rows = [], [], [], []
        for raw_label, raw_text in batch:
            label_ids.append(label_map.get(raw_label, 0))
            encoded = tokenizer(
                raw_text,
                return_tensors='pt',
                padding='max_length',
                max_length=max_len,
                truncation=True,
            )
            bow_row, _ = bow_vectorizer.transform([raw_text])
            input_ids = encoded['input_ids']
            attention_masks.append(encoded['attention_mask'])
            token_ids.append(input_ids)
            bow_rows.append(bow_row)

        stacked = (
            torch.tensor(label_ids, dtype=torch.int64),
            torch.vstack(token_ids),
            torch.vstack(attention_masks),
            torch.vstack([sparse_coo_to_tensor(row.tocoo()) for row in bow_rows]),
        )
        return tuple(tensor.to(device) for tensor in stacked)

    return DataLoader(data, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_batch)
def get_llm_dataloader(data, bow_vectorizer, llm_name, label_map, batch_size, max_len, shuffle=False, device='cpu'):
    """Build an LLM data loader whose batches are wrapped in 1-tuples.

    All arguments are forwarded unchanged to
    ``get_unwrapped_llm_dataloader``; the resulting loader is wrapped in a
    ``SingletonWrapperLoader``.

    Returns:
        A ``SingletonWrapperLoader`` around the underlying loader.
    """
    inner_loader = get_unwrapped_llm_dataloader(
        data,
        bow_vectorizer,
        llm_name,
        label_map,
        batch_size,
        max_len,
        shuffle=shuffle,
        device=device,
    )
    return SingletonWrapperLoader(inner_loader)
def get_llm_paired_dataloader(data_a, data_b, bow_vectorizer, llm_name, label_map, batch_size, max_len_a, max_len_b,
                              shuffle_both=False, shuffle_a_only=True, device='cpu'):
    """Build two LLM data loaders and pair them in a ``PairedDataLoader``.

    Both loaders share the vectorizer, tokenizer name, label map, batch size
    and device; they differ only in dataset, maximum length and shuffling.

    Args:
        data_a: Dataset for the first loader.
        data_b: Dataset for the second loader.
        max_len_a: Tokenizer length for ``data_a``.
        max_len_b: Tokenizer length for ``data_b``.
        shuffle_both: When true, both loaders shuffle.
        shuffle_a_only: When true, the first loader shuffles even if
            ``shuffle_both`` is false.  The second loader's shuffling is
            controlled by ``shuffle_both`` alone.

    Returns:
        ``PairedDataLoader(loader_a, loader_b)``.
    """
    shared_args = (bow_vectorizer, llm_name, label_map, batch_size)
    loader_a = get_unwrapped_llm_dataloader(
        data_a, *shared_args, max_len_a,
        shuffle=(shuffle_both or shuffle_a_only), device=device)
    loader_b = get_unwrapped_llm_dataloader(
        data_b, *shared_args, max_len_b,
        shuffle=shuffle_both, device=device)
    return PairedDataLoader(loader_a, loader_b)
class StratifiedPairedLLMLoader():
    """Iterate over paired batches drawn from two datasets by a stratified sampler.

    Index pairs come from a ``StratifiedDualBatchSampler`` built over the
    mapped labels of both datasets; the selected ``(label, text)`` items are
    tokenized and vectorized into tensors on the fly.

    Args:
        data_a: First dataset; must support ``len()`` and integer indexing,
            each item being a ``(label, text)`` pair.
        data_b: Second dataset, same requirements as ``data_a``.
        bow_vectorizer: Object whose ``transform([text])`` returns a 2-tuple
            whose first element offers ``tocoo()``.
        llm_name: Key of ``llm_catalog`` selecting the tokenizer.
        label_map: Mapping from raw label to integer id.  Every label in
            both datasets must be present (the sampler is built with direct
            lookups); during collation unknown labels fall back to ``0``.
        batch_size: Number of items per side in each batch.
        max_len_a: Tokenizer length for items of ``data_a``.
        max_len_b: Tokenizer length for items of ``data_b``.
        num_batches: Batches per pass; a falsy value means
            ``max(len(data_a), len(data_b)) // batch_size``.
        device: Device that every tensor of a batch is moved to.
    """

    def __init__(self, data_a, data_b, bow_vectorizer, llm_name, label_map, batch_size, max_len_a, max_len_b,
                 num_batches=0, device='cpu'):
        self.data_a = data_a
        self.data_b = data_b
        self.bow_vectorizer = bow_vectorizer
        self.llm_name = llm_name
        self.label_map = label_map
        self.batch_size = batch_size
        self.max_len_a = max_len_a
        self.max_len_b = max_len_b
        self.device = device
        if num_batches:
            self.num_batches = num_batches
        else:
            self.num_batches = max(len(data_a), len(data_b)) // batch_size

        mapped_labels_a = np.array([label_map[label] for (label, _) in data_a])
        mapped_labels_b = np.array([label_map[label] for (label, _) in data_b])
        self.stratified_sampler = StratifiedDualBatchSampler(
            mapped_labels_a, mapped_labels_b, batch_size, self.num_batches)
        self.iterator = None  # created by __iter__

        def lookup_label(raw_label):
            return label_map.get(raw_label, 0)

        self.label_pipeline = lookup_label
        self.text_pipeline = get_llm_tokenizer(llm_name)

    def __iter__(self):
        self.iterator = iter(self.stratified_sampler)
        return self

    def __len__(self):
        return self.num_batches

    def __next__(self):
        indices_a, indices_b = next(self.iterator)
        items_a = [self.data_a[idx] for idx in indices_a]
        items_b = [self.data_b[idx] for idx in indices_b]
        batch_a = self._collate_batch(items_a, self.max_len_a)
        batch_b = self._collate_batch(items_b, self.max_len_b)
        return batch_a, batch_b

    def _collate_batch(self, batch, max_len):
        """Turn ``(label, text)`` pairs into ``(labels, input_ids, attention_mask, bow)`` tensors."""
        label_ids, token_ids, attention_masks, bow_rows = [], [], [], []
        for raw_label, raw_text in batch:
            label_ids.append(self.label_pipeline(raw_label))
            encoded = self.text_pipeline(
                raw_text,
                return_tensors='pt',
                padding='max_length',
                max_length=max_len,
                truncation=True,
            )
            bow_row, _ = self.bow_vectorizer.transform([raw_text])
            input_ids = encoded['input_ids']
            attention_masks.append(encoded['attention_mask'])
            token_ids.append(input_ids)
            bow_rows.append(bow_row)

        stacked = (
            torch.tensor(label_ids, dtype=torch.int64),
            torch.vstack(token_ids),
            torch.vstack(attention_masks),
            torch.vstack([sparse_coo_to_tensor(row.tocoo()) for row in bow_rows]),
        )
        return tuple(tensor.to(self.device) for tensor in stacked)


#def get_llm_paired_stratified_dataloader(data_a, data_b, bow_vectorizer, llm_name, label_map, batch_size, max_len_a, max_len_b, device='cpu'):
#    return

##############################################
def to_label_matrix(yvs, num_labels=0):
    """Convert per-example label-id collections into a multi-hot NumPy matrix.

    Args:
        yvs: Iterable such as ``[(id1, id2, ...), (id1, id2, ...), ...]``
            holding, for every example, the ids of its labels.  It is
            traversed twice when ``num_labels`` is ``0``.
        num_labels: Width of each row.  ``0`` means "infer": one more than
            the largest id found (never less than ``1``).

    Returns:
        Tuple ``(matrix, num_labels)``; ``matrix[i, j]`` is ``1`` when id
        ``j`` occurs in ``yvs[i]`` and ``0`` otherwise.
    """
    if num_labels == 0:
        # Seeding with 0 keeps the inferred width at 1 or more.
        largest_id = max([0] + [label_id for label_ids in yvs for label_id in label_ids])
        num_labels = int(largest_id + 1)

    def multi_hot(label_ids):
        row = np.zeros(num_labels)
        row[np.array(label_ids, dtype='int64')] = 1
        return row

    matrix = np.array([multi_hot(label_ids) for label_ids in yvs])
    return matrix, num_labels
class SparseDataset():
    """Dataset over a feature matrix and optional targets, dense or scipy sparse.

    COO inputs are converted to CSR on construction because they are indexed
    in ``__getitem__``; every other input is stored unchanged.

    Args:
        data: Feature matrix with one row per example.
        targets: Targets indexed in step with ``data``, or ``None``.
    """

    def __init__(self, data: Union[np.ndarray, sp.coo_matrix, sp.csr_matrix],
                 targets: Optional[Union[np.ndarray, sp.coo_matrix, sp.csr_matrix]]):
        self.data = self._as_indexable(data)
        self.targets = self._as_indexable(targets)

    @staticmethod
    def _as_indexable(matrix):
        """Return ``matrix`` as CSR if it is stored in COO format, else unchanged."""
        # Checking the storage format (rather than the exact type) also covers
        # COO subclasses and scipy's sparse-array flavour.
        if sp.issparse(matrix) and matrix.format == 'coo':
            return matrix.tocsr()
        return matrix

    def __getitem__(self, index):
        """Return ``(data[index], targets[index])``; the target is ``None`` when there are no targets."""
        targets_i = self.targets[index] if self.targets is not None else None
        return self.data[index], targets_i

    def __len__(self):
        """Return the number of rows in ``data``."""
        return self.data.shape[0]
def sparse_coo_to_tensor(coo: sp.coo_matrix):
    """Transform a scipy COO matrix into a PyTorch sparse COO tensor.

    Args:
        coo: Matrix exposing ``row``, ``col``, ``data`` and ``shape``.

    Returns:
        A sparse ``float32`` tensor with the shape and entries of ``coo``.
    """
    # Stack the coordinates into one (2, nnz) array first: building a tensor
    # from a tuple of separate ndarrays is slow and triggers a warning.
    indices = torch.tensor(np.vstack((coo.row, coo.col)), dtype=torch.int64)
    values = torch.tensor(coo.data, dtype=torch.float32)
    # torch.sparse_coo_tensor replaces the deprecated torch.sparse.FloatTensor.
    return torch.sparse_coo_tensor(indices, values, torch.Size(coo.shape))
def sparse_batch_collate(batch):
    """Collate function that transforms one pre-batched item into PyTorch tensors.

    Args:
        batch: One-element list holding a ``(data, targets)`` pair that
            already represents a whole batch.  ``data`` is a scipy sparse
            matrix or something ``torch.FloatTensor`` accepts; ``targets``
            is ``None``, a scipy sparse matrix, or something
            ``torch.LongTensor`` accepts.

    Returns:
        Tuple ``(data, targets)``.  Sparse inputs become sparse tensors via
        ``sparse_coo_to_tensor``, dense data becomes a ``FloatTensor``, dense
        targets become a ``LongTensor`` and ``None`` targets stay ``None``.
    """
    # batch[0] since it is returned as a one element list
    data_batch, targets_batch = batch[0]

    # Any sparse format can be converted with tocoo(), so test for sparsity
    # directly rather than slicing out a row to compare its exact type.
    if sp.issparse(data_batch):
        data_batch = sparse_coo_to_tensor(data_batch.tocoo())
    else:
        data_batch = torch.FloatTensor(data_batch)

    if targets_batch is not None:
        if sp.issparse(targets_batch):
            targets_batch = sparse_coo_to_tensor(targets_batch.tocoo())
        else:
            targets_batch = torch.LongTensor(targets_batch)
    return data_batch, targets_batch
class SparseDataLoader(DataLoader):
    """``DataLoader`` over a ``SparseDataset`` that fetches each batch with one index operation.

    A ``BatchSampler`` produces a whole list of row indices per step, the
    dataset is indexed once with that list, and ``sparse_batch_collate``
    converts the result to tensors.  The outer ``DataLoader`` therefore runs
    with ``batch_size=1``.

    Args:
        X: Feature matrix, one row per example.
        y: Targets (or ``None``), indexed in step with ``X``.
        shuffle: Draw rows in random order when true, in sequential order
            otherwise.
        drop_last: Drop the final batch when it has fewer than
            ``batch_size`` rows.
        batch_size: Number of rows per batch.
        device: Device of the ``torch.Generator`` objects handed to the
            random sampler and to ``DataLoader``.
    """

    def __init__(self, X: Union[sp.csr_matrix, sp.coo_matrix], y: np.array, shuffle=False, drop_last=False,
                 batch_size=1024, device='cpu'):
        # Set before super().__init__; DataLoader rejects changes to this
        # attribute once it is initialised (and stores its own value of 1).
        self.batch_size = batch_size
        ds = SparseDataset(X, y)

        # Honour ``shuffle``: previously a RandomSampler was used unconditionally.
        if shuffle:
            row_sampler = torch.utils.data.sampler.RandomSampler(
                ds, generator=torch.Generator(device=device))
        else:
            row_sampler = torch.utils.data.sampler.SequentialSampler(ds)

        # Honour ``drop_last`` here, where the real batches are formed; on the
        # outer loader (batch_size=1) it has no effect.
        sampler = torch.utils.data.sampler.BatchSampler(
            row_sampler, batch_size=batch_size, drop_last=drop_last)

        super().__init__(ds,
                         batch_size=1,
                         collate_fn=sparse_batch_collate,
                         generator=torch.Generator(device=device),
                         sampler=sampler,
                         drop_last=drop_last)
class SingletonWrapperLoader():
    """Wrap a loader so that every batch is returned as a 1-tuple ``(batch,)``.

    Args:
        data_loader: Any re-iterable object that supports ``len()``.
    """

    def __init__(self, data_loader):
        self.data_loader = data_loader
        self.data_iter = iter(data_loader)

    def __iter__(self):
        self.data_iter = iter(self.data_loader)
        return self

    def __len__(self):
        # Ask the loader, not its iterator: plain iterators have no len().
        return len(self.data_loader)

    def __next__(self):
        batch = next(self.data_iter)
        return (batch,)

    def next(self):
        return self.__next__()


class PairedDataLoader():
    """Iterate two loaders in lockstep, yielding ``(batch1, batch2)`` pairs.

    When one loader runs out before the other it is restarted, so a pass
    lasts as long as the longer loader.  With ``data_loader2`` set to
    ``None`` the second element of every pair is ``None``.

    Args:
        data_loader1: First loader (re-iterable, supports ``len()``).
        data_loader2: Second loader, or ``None``.
    """

    def __init__(self, data_loader1, data_loader2):
        self.data_loader1 = data_loader1
        self.data_loader2 = data_loader2
        self.data_iter1 = iter(data_loader1)
        self.data_iter2 = iter(data_loader2) if data_loader2 is not None else None
        self.batch_index = 0
        self.end1 = False  # True once loader 1 has been exhausted and restarted
        self.end2 = False  # True once loader 2 has been exhausted and restarted

    def __iter__(self):
        self.data_iter1 = iter(self.data_loader1)
        self.data_iter2 = iter(self.data_loader2) if self.data_loader2 is not None else None
        self.batch_index = 0
        self.end1 = False
        self.end2 = False
        return self

    def __len__(self):
        if self.data_loader2 is not None:
            return max(len(self.data_loader1), len(self.data_loader2))
        return len(self.data_loader1)

    def __next__(self):
        try:
            batch1 = next(self.data_iter1)
        except StopIteration:
            # Stop if the other side already wrapped around (or is absent);
            # otherwise restart this side so the longer loader can finish.
            if self.end2 or self.data_loader2 is None:
                raise StopIteration
            self.data_iter1 = iter(self.data_loader1)
            self.end1 = True
            batch1 = next(self.data_iter1)

        if self.data_loader2 is None:
            return batch1, None

        try:
            batch2 = next(self.data_iter2)
        except StopIteration:
            if self.end1:
                raise StopIteration
            self.data_iter2 = iter(self.data_loader2)
            self.end2 = True
            batch2 = next(self.data_iter2)
        return batch1, batch2

    def next(self):
        return self.__next__()


class RoundRobinDataLoader():
    """Interleave several loaders, always drawing from the one with the largest unread fraction.

    Starting a pass (``__iter__``) first walks every loader once to count
    its batches, then creates fresh iterators for the real pass.

    Args:
        data_loaders: Sequence of re-iterable loaders that support ``len()``.
    """

    def __init__(self, data_loaders):
        self.num_loaders = len(data_loaders)
        self.data_loaders = data_loaders
        self.data_iters = [iter(d) for d in data_loaders]
        self.data_totals = None  # batches per loader, filled in lazily
        self.ratio_remaining = np.array([1.0 for _ in data_loaders])

    def _get_iter_length(self, it):
        """Count the items left in ``it`` by consuming it."""
        # Only normal exhaustion ends the count; other errors propagate
        # instead of being swallowed by a bare ``except``.
        return sum(1 for _ in it)

    def _set_lengths(self, iters):
        self.data_totals = [self._get_iter_length(it) for it in iters]

    def _reset_ratios(self):
        """Recount every loader and reset the per-loader unread fractions."""
        self._set_lengths([iter(d) for d in self.data_loaders])
        # An empty loader starts at 0 so it is never preferred over a loader
        # that still has batches (picking it would end the pass immediately).
        self.ratio_remaining[:] = [1.0 if total > 0 else 0.0 for total in self.data_totals]

    def __iter__(self):
        self._reset_ratios()
        self.data_iters = [iter(d) for d in self.data_loaders]
        return self

    def __len__(self):
        # Ask the loaders, not their iterators: plain iterators have no len().
        return sum(len(d) for d in self.data_loaders)

    def __next__(self):
        if len(self.ratio_remaining) == 0:
            raise StopIteration
        if self.data_totals is None:
            # __next__ called without a preceding __iter__.
            self._reset_ratios()
        it_id = int(np.argmax(self.ratio_remaining))  ## get iterator with most elements left
        batch = next(self.data_iters[it_id])
        total = self.data_totals[it_id]
        self.ratio_remaining[it_id] = ((self.ratio_remaining[it_id] * total) - 1) / total
        return batch

    def next(self):
        return self.__next__()


def _init_data(data, allow_empty, default_name):
    """Convert data into canonical form.

    Args:
        data: Payload to wrap; ``None`` is replaced by an empty list.
        allow_empty: Whether ``data`` may be ``None``.
        default_name: Name paired with the payload.

    Returns:
        A one-element list ``[(default_name, data)]``.
    """
    assert (data is not None) or allow_empty
    if data is None:
        data = []
    data = OrderedDict([(default_name, data)])  # pylint: disable=redefined-variable-type
    return list(data.items())
def load_vocab(vocab_file, encoding='utf-8'):
    """Load a pre-derived vocabulary from a file with one word per line.

    Only the text before the first space of each line is used, stripped of
    surrounding whitespace.

    Note: this is a bit of a hack - each word is given a count that
    decreases with its position so that the vocab items are sorted IN THE
    ORDER THEY ARE FOUND IN THE FILE.

    Args:
        vocab_file: Path of the vocabulary file.
        encoding: Text encoding used to read the file.

    Returns:
        The result of ``build_vocab`` applied to the word -> count mapping.
    """
    with io.open(vocab_file, 'r', encoding=encoding) as fp:
        words = [line.split(' ')[0].strip() for line in fp]

    total = len(words)
    # Earlier lines receive larger counts; a repeated word keeps its first
    # position in the mapping but takes the count of its last occurrence.
    w_dict = {}
    for position, word in enumerate(words):
        w_dict[word] = total - position
    return build_vocab(w_dict)
class StratifiedDualBatchSampler:
    """Stratified batch sampling over two label arrays.

    Each step yields a pair of index lists, one into ``y_a`` and one into
    ``y_b``.  On even steps the classes are drawn from the class
    frequencies of ``y_a``, on odd steps from those of ``y_b``; the other
    side uses the same class at each position, or - when that class does
    not occur there - a class taken from a shuffled list of its own classes.
    Within a class, example indices are drawn at random with replacement.

    Args:
        y_a: 1-D NumPy array of non-negative integer class ids.
        y_b: 1-D NumPy array of non-negative integer class ids.
        batch_size: Number of indices per side in every batch.
        num_batches: Number of batches produced per pass.
        shuffle: Stored on the instance; not consulted by the sampler.
    """

    def __init__(self, y_a, y_b, batch_size, num_batches, shuffle=True):
        assert len(y_a.shape) == 1  # 'label array must be 1D'
        assert len(y_b.shape) == 1
        self.y_a = y_a
        self.y_b = y_b
        self.shuffle = shuffle
        self.batch_size = batch_size
        self.num_batches = num_batches
        self.counts_a = Counter(y_a)
        self.counts_b = Counter(y_b)

        # Both sides share one class-id space: 0 .. largest id seen anywhere.
        num_classes = int(max(np.max(y_a), np.max(y_b))) + 1

        # Relative class frequencies; classes absent from a side keep weight 0.
        self.class_weights_a = [0] * num_classes
        self.class_weights_b = [0] * num_classes
        for k in self.counts_a:
            self.class_weights_a[k] = self.counts_a[k] / len(y_a)
        for k in self.counts_b:
            self.class_weights_b[k] = self.counts_b[k] / len(y_b)

        # Positions of every class's examples within each label array.
        self.class_indices_a = [list(np.where(y_a == c)[0]) for c in range(num_classes)]
        self.class_indices_b = [list(np.where(y_b == c)[0]) for c in range(num_classes)]

        self.a_only = self.counts_a.keys() - self.counts_b.keys()
        self.b_only = self.counts_b.keys() - self.counts_a.keys()
        self.use_with_replacement = (self.batch_size > len(self.class_weights_a))

    def _pop_leave_last(self, li):
        """Pop and return the last item of ``li``, but never empty the list."""
        if len(li) == 1:
            return li[0]
        return li.pop()

    def __iter__(self):
        draws_per_class = self.num_batches * self.batch_size
        if draws_per_class <= 0:
            # Nothing to yield; RandomSampler also rejects num_samples=0.
            return

        samplers_a = [iter(RandomSampler(indices, replacement=True, num_samples=draws_per_class))
                      for indices in self.class_indices_a]
        samplers_b = [iter(RandomSampler(indices, replacement=True, num_samples=draws_per_class))
                      for indices in self.class_indices_b]

        # Drawing without replacement needs at least ``batch_size`` classes
        # with non-zero weight on the side being sampled, so also fall back
        # to replacement when that side has too few distinct classes.
        replace_a = self.use_with_replacement or self.batch_size > len(self.counts_a)
        replace_b = self.use_with_replacement or self.batch_size > len(self.counts_b)

        for i in range(self.num_batches):
            if i % 2 == 0:
                classes_a = list(WeightedRandomSampler(self.class_weights_a, self.batch_size,
                                                       replacement=replace_a))
                b_list = list(self.counts_b)
                random.shuffle(b_list)
                classes_b = [self._pop_leave_last(b_list) if a in self.a_only else a for a in classes_a]
            else:
                classes_b = list(WeightedRandomSampler(self.class_weights_b, self.batch_size,
                                                       replacement=replace_b))
                a_list = list(self.counts_a)
                random.shuffle(a_list)
                classes_a = [self._pop_leave_last(a_list) if b in self.b_only else b for b in classes_b]
            batch_indices_a = [self.class_indices_a[c][next(samplers_a[c])] for c in classes_a]
            batch_indices_b = [self.class_indices_b[c][next(samplers_b[c])] for c in classes_b]
            yield (batch_indices_a, batch_indices_b)

    def __len__(self):
        # num_batches is an int; calling len() on it raised TypeError.
        return self.num_batches