
Commit b717b76

Patching module. Remove custom engine HF cache. Alteration module that has default alterations for certain model_types.

JadenFiotto-Kaufman committed Sep 27, 2023
1 parent f273f9a commit b717b76

Showing 35 changed files with 255 additions and 100,570 deletions.
36 changes: 24 additions & 12 deletions engine/Model.py
@@ -17,6 +17,7 @@
 from transformers.generation.utils import GenerateOutput

 from . import CONFIG
+from .alteration import MODEL_TYPE_TO_ALTERATION
 from .contexts.Generator import Generator
 from .editing.Editor import Edit, Editor
 from .editing.GraphEdit import GraphEdit
@@ -25,6 +26,7 @@
 from .intervention import intervene
 from .logger import logger
 from .Module import Module
+from .patching import Patcher


 class Model:
@@ -41,9 +43,10 @@ class Model:
         edits (List[Edit]): desc
     """

-    def __init__(self, model_name_or_path: str) -> None:
+    def __init__(self, model_name_or_path: str, alter=True) -> None:
         self.model_name_or_path = model_name_or_path
         self.edits: List[Edit] = list()
+        self.alter = alter

         # Use init_empty_weights to create the graph, i.e. the specified model with no loaded parameters,
         # to use for finding shapes of Module inputs and outputs, as well as replacing torch.nn.Module
@@ -53,20 +56,26 @@ def __init__(self, model_name_or_path: str) -> None:

         with accelerate.init_empty_weights(include_buffers=True):
             self.config = AutoConfig.from_pretrained(
-                self.model_name_or_path, cache_dir=CONFIG.APP.MODEL_CACHE_PATH
+                self.model_name_or_path
             )

             self.tokenizer = AutoTokenizer.from_pretrained(
                 self.model_name_or_path,
                 config=self.config,
                 padding_side="left",
-                cache_dir=CONFIG.APP.MODEL_CACHE_PATH,
+                trust_remote_code=True,
             )
             self.tokenizer.pad_token = self.tokenizer.eos_token

-            self.meta_model: PreTrainedModel = Module.wrap(
-                AutoModelForCausalLM.from_config(self.config)
-            )
+            # Check for alterations for the model type and, if any exist, apply the patch
+            with MODEL_TYPE_TO_ALTERATION[
+                self.config.model_type
+            ] if self.alter and self.config.model_type in MODEL_TYPE_TO_ALTERATION else Patcher():
+                self.meta_model: PreTrainedModel = Module.wrap(
+                    AutoModelForCausalLM.from_config(
+                        self.config, trust_remote_code=True
+                    )
+                )

         for name, module in self.meta_model.named_children():
             # Wrap all modules in our Module class.
@@ -208,12 +217,15 @@ def dispatch(self, device_map="auto") -> None:
         if self.local_model is None:
             logger.debug(f"Dispatching `{self.model_name_or_path}`...")

-            self.local_model = AutoModelForCausalLM.from_pretrained(
-                self.model_name_or_path,
-                config=self.config,
-                device_map=device_map,
-                cache_dir=CONFIG.APP.MODEL_CACHE_PATH,
-            )
+            with MODEL_TYPE_TO_ALTERATION[
+                self.config.model_type
+            ] if self.alter and self.config.model_type in MODEL_TYPE_TO_ALTERATION else Patcher():
+                self.local_model = AutoModelForCausalLM.from_pretrained(
+                    self.model_name_or_path,
+                    config=self.config,
+                    device_map=device_map,
+                    trust_remote_code=True,
+                )

             logger.debug(f"Dispatched `{self.model_name_or_path}`")
         else:
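From the caller's side nothing changes unless the new flag is used: when a model type has a registered alteration, it is patched in transparently while the meta model and the dispatched model are built. A minimal usage sketch (not part of the diff; only the alter flag is new in this commit):

from engine import Model

# "gpt2" has a default alteration registered, so GPT2Attention is swapped
# for the altered variant while the model is constructed.
model = Model("gpt2")

# Opt out of alterations and load the stock Hugging Face implementation.
unaltered = Model("gpt2", alter=False)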
5 changes: 1 addition & 4 deletions engine/__init__.py
@@ -3,13 +3,10 @@
 import yaml

 from .modeling.Config import ConfigModel
-from .monkey_patching import *
+from .patching import *

 PATH = os.path.dirname(os.path.abspath(__file__))
 with open(os.path.join(PATH, "config.yaml"), "r") as file:
     CONFIG = ConfigModel(**yaml.safe_load(file))

-CONFIG.APP.MODEL_CACHE_PATH = os.path.join(PATH, CONFIG.APP.MODEL_CACHE_PATH)
-
-
 from .Model import Model
3 changes: 3 additions & 0 deletions engine/alteration/__init__.py
@@ -0,0 +1,3 @@
from .gpt import GPT2Patcher

MODEL_TYPE_TO_ALTERATION = {"gpt2": GPT2Patcher}
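Model.__init__ and Model.dispatch consume this mapping through a conditional context manager: the registered patcher is entered when one exists and altering is enabled, otherwise an empty Patcher() is entered as a no-op. A condensed sketch of that lookup (alteration_for is a hypothetical helper written here only to spell the logic out; config stands for the loaded AutoConfig):

from engine.alteration import MODEL_TYPE_TO_ALTERATION
from engine.patching import Patcher


def alteration_for(config, alter=True):
    # Fall back to an empty Patcher, which patches nothing, when altering is
    # disabled or no alteration is registered for this model type.
    if alter and config.model_type in MODEL_TYPE_TO_ALTERATION:
        return MODEL_TYPE_TO_ALTERATION[config.model_type]
    return Patcher()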
90 changes: 90 additions & 0 deletions engine/alteration/gpt.py
@@ -0,0 +1,90 @@
from __future__ import annotations

from typing import Optional, Tuple, Union

import torch
from transformers.models import gpt2

from .. import util
from ..patching import Patch, Patcher


class GPT2AttentionAltered(gpt2.modeling_gpt2.GPT2Attention):
    def __init__(self, config, is_cross_attention=False, layer_idx=None):
        super().__init__(config, is_cross_attention, layer_idx)

        self.query = util.WrapperModule()
        self.key = util.WrapperModule()
        self.value = util.WrapperModule()

    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
        if encoder_hidden_states is not None:
            if not hasattr(self, "q_attn"):
                raise ValueError(
                    "If class is used as cross attention, the weights `q_attn` have to be defined. "
                    "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
                )

            query = self.q_attn(hidden_states)
            key, value = self.c_attn(encoder_hidden_states).split(
                self.split_size, dim=2
            )
            attention_mask = encoder_attention_mask
        else:
            query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)

        query = self._split_heads(query, self.num_heads, self.head_dim)
        key = self._split_heads(key, self.num_heads, self.head_dim)
        value = self._split_heads(value, self.num_heads, self.head_dim)

        # Altered -------------

        query = self.query(query)
        key = self.key(key)
        value = self.value(value)

        # ---------------------

        if layer_past is not None:
            past_key, past_value = layer_past
            key = torch.cat((past_key, key), dim=-2)
            value = torch.cat((past_value, value), dim=-2)

        if use_cache is True:
            present = (key, value)
        else:
            present = None

        if self.reorder_and_upcast_attn:
            attn_output, attn_weights = self._upcast_and_reordered_attn(
                query, key, value, attention_mask, head_mask
            )
        else:
            attn_output, attn_weights = self._attn(
                query, key, value, attention_mask, head_mask
            )

        attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
        attn_output = self.c_proj(attn_output)
        attn_output = self.resid_dropout(attn_output)

        outputs = (attn_output, present)
        if output_attentions:
            outputs += (attn_weights,)

        return outputs  # a, present, (attentions)


GPT2Patcher = Patcher([
    Patch(gpt2.modeling_gpt2.GPT2Attention, GPT2AttentionAltered)
])
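GPT2Patcher is a Patcher instance used as a context manager, so the class swap is scoped: inside the context a freshly built GPT-2 picks up GPT2AttentionAltered (and with it the query/key/value wrapper hook points), and the original class is restored on exit. A hedged sketch of that behavior, assuming Patch replaces the GPT2Attention attribute on the transformers module for the duration of the context and a transformers version where GPT2Block looks the class up from that module:

from transformers import AutoConfig, AutoModelForCausalLM
from engine.alteration import GPT2Patcher

config = AutoConfig.from_pretrained("gpt2")

with GPT2Patcher:
    # Built inside the context: attention layers are GPT2AttentionAltered and
    # expose .query/.key/.value WrapperModule submodules.
    altered_model = AutoModelForCausalLM.from_config(config)

# Built outside the context: stock GPT2Attention again.
stock_model = AutoModelForCausalLM.from_config(config)

print(type(altered_model.transformer.h[0].attn).__name__)  # GPT2AttentionAltered
print(type(stock_model.transformer.h[0].attn).__name__)    # GPT2Attention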
4 changes: 1 addition & 3 deletions engine/config.yaml
@@ -1,4 +1,2 @@
 API:
-  HOST: nagoya.research.khoury.northeastern.edu:5000
-APP:
-  MODEL_CACHE_PATH: ./model_checkpoints
+  HOST: localhost:5000
12 changes: 6 additions & 6 deletions engine/editing/Editor.py
@@ -7,23 +7,23 @@
 class Edit:

     @abstractmethod
-    def edit(self, model: torch.nn.Module):
+    def edit(self, obj: torch.nn.Module):
         pass

     @abstractmethod
-    def restore(self, model: torch.nn.Module):
+    def restore(self, obj: torch.nn.Module):
         pass


 class Editor:
-    def __init__(self, model: torch.nn.Module, edits: List[Edit]) -> None:
-        self.model = model
+    def __init__(self, obj: object, edits: List[Edit]) -> None:
+        self.obj = obj
         self.edits = edits

     def __enter__(self) -> Editor:
         for edit in self.edits:
-            edit.edit(self.model)
+            edit.edit(self.obj)

     def __exit__(self, exc_type, exc_val, exc_tb) -> None:
         for edit in self.edits:
-            edit.restore(self.model)
+            edit.restore(self.obj)
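Editor now works on any object, not just torch modules: it applies every Edit on entry and restores them on exit. A toy sketch with a hypothetical Edit subclass (SetAttrEdit is illustrative only and not part of the repo):

import torch

from engine.editing.Editor import Edit, Editor


class SetAttrEdit(Edit):
    # Toy edit: temporarily set an attribute on the target object.
    def __init__(self, name, value):
        self.name = name
        self.value = value

    def edit(self, obj):
        self.previous = getattr(obj, self.name, None)
        setattr(obj, self.name, self.value)

    def restore(self, obj):
        setattr(obj, self.name, self.previous)


layer = torch.nn.Linear(4, 4)

with Editor(layer, [SetAttrEdit("tag", "edited")]):
    print(layer.tag)  # "edited" while the Editor context is active
# on exit the previous value (None here) is put back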
11 changes: 1 addition & 10 deletions engine/editing/WrapperModuleEdit.py
@@ -3,15 +3,6 @@
 from .. import util
 from .Editor import Edit

-class WrapperModule(torch.nn.Module):
-
-    def forward(self, *args, **kwargs):
-
-        if len(args) == 1:
-            args = args[0]
-
-        return args
-

 class WrapperModuleEdit(Edit):
     def __init__(self, module_path: str, module_name: str) -> None:
@@ -20,7 +11,7 @@ def __init__(self, module_path: str, module_name: str) -> None:
         self.module_path = module_path
         self.module_name = module_name

-        self.wrapper = WrapperModule()
+        self.wrapper = util.WrapperModule()

     def edit(self, model: torch.nn.Module):
         module: torch.nn.Module = util.fetch_attr(model, self.module_path)
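The pass-through WrapperModule now lives in engine.util so it can be shared (WrapperModuleEdit above and GPT2AttentionAltered both use it). It is an identity module: a single argument is returned unchanged and several arguments come back as a tuple, which makes it a convenient named hook point that does not alter computation. A small sketch of that behavior, assuming util.WrapperModule keeps the forward shown in the removed class above (import path assumed from the `from .. import util` usage):

import torch

from engine import util

wrapper = util.WrapperModule()

x = torch.ones(2, 3)
print(wrapper(x) is x)      # True: a single input passes straight through
print(type(wrapper(x, x)))  # tuple: multiple inputs are returned together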
10 changes: 5 additions & 5 deletions engine/fx/Graph.py
@@ -7,8 +7,8 @@

 from .. import util
 from .Node import Node
-from .Patcher import Patcher
-from .Proxy import Proxy
+from ..patching import Patcher, Patch
+from .Proxy import Proxy, proxy_wrapper


 class Graph:
@@ -90,9 +90,9 @@ def get_argument_value(param: inspect.Parameter, idx: int):
         # Some methods cannot be caught because they aren't torch functions or don't play nice with __torch_function__.
         # So the patcher replaces those methods with something that catches proxies and returns proxies.
         with Patcher() as patcher:
-            patcher.patch(torch.full)
-            patcher.patch(torch.finfo)
-            patcher.patch(torch.arange)
+            patcher.add(Patch(torch.full, proxy_wrapper(torch.full)))
+            patcher.add(Patch(torch.finfo, proxy_wrapper(torch.finfo)))
+            patcher.add(Patch(torch.arange, proxy_wrapper(torch.arange)))

            # Run forward with root module proxy and arguments
            output: Proxy = forward(graph.module_proxy, *arguments)
49 changes: 0 additions & 49 deletions engine/fx/Patcher.py

This file was deleted.

33 changes: 32 additions & 1 deletion engine/fx/Proxy.py
@@ -61,7 +61,6 @@ def __getitem__(self, key: Union[Proxy, Any]) -> Proxy:
         )

     def __setitem__(self, key: Union[Proxy, Any], value: Union[Proxy, Any]) -> None:
-
         item_proxy = self[key]

         update = item_proxy.node.__class__.update
@@ -177,3 +176,35 @@ def __torch_function__(cls, orig_method, types, args=None, kwargs=None) -> Proxy
             args=args,
             kwargs=kwargs,
         )
+
+
+from functools import wraps
+
+
+def proxy_wrapper(fn) -> None:
+    @wraps(fn)
+    def patched(*args, **kwargs):
+        arguments = list(args) + list(kwargs.values())
+
+        node = None
+
+        for arg in arguments:
+            if isinstance(arg, Proxy):
+                node = arg.node
+
+                break
+
+        if node is not None:
+            value = fn(
+                *node.prepare_proxy_values(args),
+                **node.prepare_proxy_values(kwargs),
+            )
+
+            return node.graph.add(
+                graph=node.graph, value=value, target=fn, args=args, kwargs=kwargs
+            )
+
+        else:
+            return fn(*args, **kwargs)
+
+    return patched
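proxy_wrapper covers functions that slip past __torch_function__ (torch.full, torch.finfo, torch.arange in Graph.py above): if any argument is a Proxy, the call is recorded as a node on that proxy's graph and a new Proxy is returned; otherwise the original function runs as usual. A short sketch of the non-proxy path (import path assumed; the printed value is standard torch.arange output):

import torch

from engine.fx.Proxy import proxy_wrapper

wrapped_arange = proxy_wrapper(torch.arange)

# With no Proxy among the arguments the wrapper simply falls through
# to the real torch.arange.
print(wrapped_arange(5))  # tensor([0, 1, 2, 3, 4])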
(Two of the remaining changed files are empty; six more were deleted. Their paths are not shown here.)
1 change: 0 additions & 1 deletion engine/model_checkpoints/models--gpt2/refs/main

This file was deleted.
