Исходный код abs_text_attack.core.interfaces.model

from textattack.models.wrappers import PyTorchModelWrapper
from textattack.models.wrappers import ModelWrapper # We have to use ModelWrapper class from text attack lib for our abc
import numpy as np
import torch

[документация] class TorchModel(ModelWrapper): """ Class designed specifically for handling torch models """ def __init__(): raise NotImplementedError() def __call__(): raise NotImplementedError()
[документация] def tokenize(): # redefines silly tokenize finction from ModelWrapper raise NotImplementedError()
[документация] class SklearnModel(ModelWrapper): """ Class designed specifically for handling sklearn models. Vectorizer should implement transform and model should implement 'predict_proba'. """ def __init__(self, model, vectorizer): self.model = model self.vectorizer = vectorizer def __call__(self, text_input_list, batch_size=None): # TODO: check data types vectorized_text_matrix = self.vectorizer.transform(text_input_list).toarray() return self.model.predict_proba(vectorized_text_matrix)
[документация] class HFModel(PyTorchModelWrapper): """ Класс-обертка для моделей с HuggingFace. На данный момент обернуты: - AutoModelForSequenceClassification. Всегда идет вместе с токенизатором, базовый класс для моделей из коробки с HF - pipeline. Токенизатор встроенный, также идет как базовый класс для моделей из коробки с HF - SetFitModel. Модели few-shot обучения. Основаны на SentenceTransformers, но образуют свой класс. """ def __init__(self, model, model_type='AutoModelForSequenceClassification', tokenizer=None): self.model = model self.tokenizer = tokenizer if not (model_type in set(['AutoModelForSequenceClassification', 'SetFitModel', 'pipeline'])): raise ValueError('Unknown type of model') self.model_type = model_type def __call__(self, text_input_list): try: model_device = next(self.model.parameters()).device except: model_device = 'cpu' self.model.to(model_device) if self.model_type == 'AutoModelForSequenceClassification': # Default max length is set to be int(1e30), so we force 512 to enable batching. max_length = ( 512 if self.tokenizer.model_max_length == int(1e30) else self.tokenizer.model_max_length ) inputs_dict = self.tokenizer( text_input_list, add_special_tokens=True, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) inputs_dict.to(model_device) with torch.no_grad(): outputs = self.model(**inputs_dict) return outputs.logits.detach() if self.model_type == 'pipeline': with torch.no_grad(): outputs = self.model(text_input_list, return_all_scores=True) for i in range(len(outputs)): outputs[i] = [item['score'].detach() for item in outputs[i]] return outputs if self.model_type == 'SetFitModel': # Проверить detach() with torch.no_grad(): if len(np.array(text_input_list).shape) == 0: outputs = self.model.predict_proba(text_input_list) else: outputs = self.model.predict_proba(text_input_list) return outputs.tolist()
# TODO: Написать нормальное вычисление градиентов для ситуации когда мы с нуля обучаем, # сейчас этот код не будет учить # We can surely use corresponding model wrappers from textattack, # but redefining them can give us more freedom # Создать фабрику моделей для динамической загрузки и инициализации разных типов моделей. # Для каждого типа модели должна быть своя логика обработки параметров.