from textattack.models.wrappers import PyTorchModelWrapper
from textattack.models.wrappers import ModelWrapper # We have to use ModelWrapper class from text attack lib for our abc
import numpy as np
import torch
[документация]
class TorchModel(ModelWrapper):
"""
Class designed specifically for handling torch models
"""
def __init__():
raise NotImplementedError()
def __call__():
raise NotImplementedError()
[документация]
def tokenize():
# redefines silly tokenize finction from ModelWrapper
raise NotImplementedError()
[документация]
class SklearnModel(ModelWrapper):
"""
Class designed specifically for handling sklearn models.
Vectorizer should implement transform and model should implement 'predict_proba'.
"""
def __init__(self, model, vectorizer):
self.model = model
self.vectorizer = vectorizer
def __call__(self, text_input_list, batch_size=None):
# TODO: check data types
vectorized_text_matrix = self.vectorizer.transform(text_input_list).toarray()
return self.model.predict_proba(vectorized_text_matrix)
[документация]
class HFModel(PyTorchModelWrapper):
"""
Класс-обертка для моделей с HuggingFace. На данный момент обернуты:
- AutoModelForSequenceClassification. Всегда идет вместе с токенизатором, базовый класс для моделей из коробки с HF
- pipeline. Токенизатор встроенный, также идет как базовый класс для моделей из коробки с HF
- SetFitModel. Модели few-shot обучения. Основаны на SentenceTransformers, но образуют свой класс.
"""
def __init__(self, model, model_type='AutoModelForSequenceClassification', tokenizer=None):
self.model = model
self.tokenizer = tokenizer
if not (model_type in set(['AutoModelForSequenceClassification', 'SetFitModel', 'pipeline'])):
raise ValueError('Unknown type of model')
self.model_type = model_type
def __call__(self, text_input_list):
try:
model_device = next(self.model.parameters()).device
except:
model_device = 'cpu'
self.model.to(model_device)
if self.model_type == 'AutoModelForSequenceClassification':
# Default max length is set to be int(1e30), so we force 512 to enable batching.
max_length = (
512
if self.tokenizer.model_max_length == int(1e30)
else self.tokenizer.model_max_length
)
inputs_dict = self.tokenizer(
text_input_list,
add_special_tokens=True,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
inputs_dict.to(model_device)
with torch.no_grad():
outputs = self.model(**inputs_dict)
return outputs.logits.detach()
if self.model_type == 'pipeline':
with torch.no_grad():
outputs = self.model(text_input_list, return_all_scores=True)
for i in range(len(outputs)):
outputs[i] = [item['score'].detach() for item in outputs[i]]
return outputs
if self.model_type == 'SetFitModel':
# Проверить detach()
with torch.no_grad():
if len(np.array(text_input_list).shape) == 0:
outputs = self.model.predict_proba(text_input_list)
else:
outputs = self.model.predict_proba(text_input_list)
return outputs.tolist()
# TODO: Написать нормальное вычисление градиентов для ситуации когда мы с нуля обучаем,
# сейчас этот код не будет учить
# We can surely use corresponding model wrappers from textattack,
# but redefining them can give us more freedom
# Создать фабрику моделей для динамической загрузки и инициализации разных типов моделей.
# Для каждого типа модели должна быть своя логика обработки параметров.