Add args to Model(...) to be used during dispatch (including device_map). Generator now has kwarg 'server' instead. No need to specify device in each generate call.
JadenFiotto-Kaufman committed Sep 27, 2023
1 parent a21cca8 commit 62fb784
Showing 15 changed files with 51 additions and 60 deletions.
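
Taken together, the change moves device placement (and any other arguments destined for AutoModelForCausalLM.from_pretrained) from each generate call to model construction, and replaces the device_map='server' sentinel with an explicit boolean flag. A minimal before/after sketch, distilled from the example scripts updated below:

    # Before this commit: device chosen on every generate call,
    # remote execution selected with the sentinel device_map='server'.
    model = Model('gpt2')
    with model.generate(device_map='cuda:0', max_new_tokens=3) as generator:
        ...

    # After this commit: the device map is given once at construction
    # and forwarded when the model is dispatched; server execution is a flag.
    model = Model('gpt2', device_map='cuda:0')
    with model.generate(max_new_tokens=3) as generator:
        ...
    with model.generate(server=True) as generator:
        ...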
21 changes: 7 additions & 14 deletions engine/Model.py
@@ -43,10 +43,14 @@ class Model:
         edits (List[Edit]): desc
     """
 
-    def __init__(self, model_name_or_path: str, alter=True) -> None:
+    def __init__(self, model_name_or_path: str, *args, alter=True, **kwargs) -> None:
         self.model_name_or_path = model_name_or_path
         self.edits: List[Edit] = list()
         self.alter = alter
+        self.args = args
+        self.kwargs = kwargs
+
+        self.kwargs['trust_remote_code'] = True
 
         # Use init_empty_weights to create graph i.e the specified model with no loaded parameters,
         # to use for finding shapes of Module inputs and outputs, as well as replacing torch.nn.Module
@@ -222,22 +226,11 @@ def dispatch(self, device_map="auto") -> None:
             ] if self.alter and self.config.model_type in MODEL_TYPE_TO_ALTERATION else Patcher():
                 self.local_model = AutoModelForCausalLM.from_pretrained(
                     self.model_name_or_path,
+                    *self.args,
                     config=self.config,
                     device_map=device_map,
-                    trust_remote_code=True,
+                    **self.kwargs
                 )
 
                 logger.debug(f"Dispatched `{self.model_name_or_path}`")
-        else:
-            if isinstance(device_map, str) and device_map != "auto":
-                # TODO
-                # self.local_model = accelerate.dispatch_model(
-                #     self.local_model, device_map
-                # )
-
-                pass
-            else:
-                self.local_model.to(device_map)
 
     def modulize(self, module: Module, node_name: str, module_name: str) -> None:
         """_summary_
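
The constructor change above implements a deferred-loading pattern: extra positional and keyword arguments are stored at construction time and only forwarded to AutoModelForCausalLM.from_pretrained when dispatch() runs, so no weights are loaded until a generation actually needs them. A minimal standalone sketch of the pattern, with the alteration/patching machinery omitted (LazyModel is an illustrative name, not the real class; the sketch also resolves device_map precedence explicitly to avoid passing it twice):

    from transformers import AutoModelForCausalLM

    class LazyModel:
        """Stores loading arguments now; loads weights on first dispatch."""

        def __init__(self, model_name_or_path: str, *args, **kwargs) -> None:
            self.model_name_or_path = model_name_or_path
            self.args = args
            self.kwargs = kwargs
            # Mirrors the committed code, which forces this on unconditionally.
            self.kwargs['trust_remote_code'] = True
            self.local_model = None

        def dispatch(self, device_map="auto"):
            # Load weights only once, forwarding the stored arguments.
            # A device_map given at construction wins over the default here.
            if self.local_model is None:
                kwargs = dict(self.kwargs)
                kwargs.setdefault('device_map', device_map)
                self.local_model = AutoModelForCausalLM.from_pretrained(
                    self.model_name_or_path,
                    *self.args,
                    **kwargs,
                )
            return self.local_model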
9 changes: 4 additions & 5 deletions engine/contexts/Generator.py
@@ -19,7 +19,6 @@ class Generator:
     Attributes:
         model (Model): Model object this is a generator for.
-        device_map (Union[str,Dict]): What device/device map to run the model on. Defaults to 'server'
         blocking (bool): If when using device_map='server', block and wait form responses. Otherwise have to manually
             request a response.
         args (List[Any]): Arguments for calling the model.
@@ -36,12 +35,12 @@ def __init__(
         self,
         model: "Model",
         *args,
-        device_map: Union[str, Dict] = "server",
         blocking: bool = True,
+        server: bool = False,
         **kwargs,
     ) -> None:
         self.model = model
-        self.device_map = device_map
+        self.server = server
         self.blocking = blocking
         self.args = args
         self.kwargs = kwargs
@@ -62,14 +61,14 @@ def __enter__(self) -> Generator:
 
     def __exit__(self, exc_type, exc_val, exc_tb) -> None:
         """On exit, run and generate using the model whether locally or on the server."""
-        if self.device_map == "server":
+        if self.server:
             self.run_server()
         else:
             self.run_local()
 
     def run_local(self):
         # Dispatch the model to the correct device.
-        self.model.dispatch(device_map=self.device_map)
+        self.model.dispatch()
 
         # Run the model and store the output.
         self.output = self.model(self.prompts, self.graph, *self.args, **self.kwargs)
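
On the caller's side, __exit__ now routes on a plain boolean: server=True sends the run to run_server(), anything else dispatches and generates locally. Usage mirroring the updated examples (the second block assumes a reachable server backend):

    model = Model('gpt2', device_map='cuda:0')

    # On exit, run_local() fires: dispatch() loads the model, then it generates.
    with model.generate(max_new_tokens=1) as generator:
        with generator.invoke('Hello world') as invoker:
            pass

    # On exit, run_server() fires instead; nothing is loaded locally.
    with model.generate(server=True) as generator:
        with generator.invoke('Hello world') as invoker:
            pass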
2 changes: 1 addition & 1 deletion engine/tests/conftest.py
@@ -7,4 +7,4 @@ def pytest_generate_tests(metafunc):
     # if the argument is specified in the list of test "fixturenames".
     option_value = metafunc.config.option.device
     if 'device' in metafunc.fixturenames and option_value is not None:
-        metafunc.parametrize("device", [option_value])
+        metafunc.parametrize("device", [option_value], scope='module')
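
The scope='module' argument matters because the gpt2 fixture below is module-scoped and now requests the parametrized device fixture; parametrizing device at the default function scope would make pytest raise a ScopeMismatch error. The --device option that feeds metafunc.config.option.device is registered elsewhere in conftest.py, presumably along these lines (a hypothetical sketch, not part of this diff):

    def pytest_addoption(parser):
        # Hypothetical registration of the CLI option read above as
        # metafunc.config.option.device.
        parser.addoption("--device", action="store", default=None)

Tests would then be invoked as, e.g., pytest engine/tests --device cuda:0.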
30 changes: 15 additions & 15 deletions engine/tests/test_base.py
@@ -3,26 +3,26 @@
 import torch
 
 @pytest.fixture(scope='module')
-def gpt2():
-    return engine.Model('gpt2')
+def gpt2(device:str):
+    return engine.Model('gpt2', device_map=device)
 
 @pytest.fixture
 def MSG_prompt():
     return "Madison Square Garden is located in the city of"
 
-def test_generation(gpt2:engine.Model, device:str, MSG_prompt:str):
+def test_generation(gpt2:engine.Model, MSG_prompt:str):
 
-    with gpt2.generate(device_map=device, max_new_tokens=3) as generator:
+    with gpt2.generate(max_new_tokens=3) as generator:
         with generator.invoke(MSG_prompt) as invoker:
             pass
 
     output = gpt2.tokenizer.decode(generator.output[0])
 
     assert output == "Madison Square Garden is located in the city of New York City"
 
-def test_save(gpt2:engine.Model, device:str):
+def test_save(gpt2:engine.Model):
 
-    with gpt2.generate(device_map=device, max_new_tokens=1) as generator:
+    with gpt2.generate(max_new_tokens=1) as generator:
         with generator.invoke('Hello world') as invoker:
 
             hs = gpt2.transformer.h[-1].output[0].save()
@@ -32,9 +32,9 @@ def test_save(gpt2:engine.Model, device:str):
     assert hs.value.ndim == 3
 
 
-def test_set(gpt2:engine.Model, device:str):
+def test_set(gpt2:engine.Model):
 
-    with gpt2.generate(device_map=device, max_new_tokens=1) as generator:
+    with gpt2.generate(max_new_tokens=1) as generator:
         with generator.invoke('Hello world') as invoker:
 
             pre = gpt2.transformer.h[-1].output[0].save()
@@ -49,9 +49,9 @@ def test_set(gpt2:engine.Model, device:str):
     assert (post.value == 0).all().item()
     assert output != "Madison Square Garden is located in the city of New"
 
-def test_adhoc_module(gpt2:engine.Model, device:str):
+def test_adhoc_module(gpt2:engine.Model):
 
-    with gpt2.generate(device_map=device) as generator:
+    with gpt2.generate() as generator:
         with generator.invoke('The Eiffel Tower is in the city of') as invoker:
 
             hidden_states = gpt2.transformer.h[-1].output[0]
@@ -63,9 +63,9 @@ def test_adhoc_module(gpt2:engine.Model, device:str):
 
     assert output == "\n-el Tower is a the middle centre Paris"
 
-def test_embeddings_set1(gpt2:engine.Model, device:str, MSG_prompt:str):
+def test_embeddings_set1(gpt2:engine.Model,MSG_prompt:str):
 
-    with gpt2.generate(device_map=device, max_new_tokens=3) as generator:
+    with gpt2.generate( max_new_tokens=3) as generator:
 
         with generator.invoke(MSG_prompt) as invoker:
 
@@ -81,17 +81,17 @@ def test_embeddings_set1(gpt2:engine.Model, device:str, MSG_prompt:str):
     assert output1 == "Madison Square Garden is located in the city of New York City"
     assert output2 == "_ _ _ _ _ _ _ _ _ New York City"
 
-def test_embeddings_set2(gpt2:engine.Model, device:str, MSG_prompt:str):
+def test_embeddings_set2(gpt2:engine.Model, MSG_prompt:str):
 
-    with gpt2.generate(device_map=device, max_new_tokens=3) as generator:
+    with gpt2.generate(max_new_tokens=3) as generator:
 
         with generator.invoke(MSG_prompt) as invoker:
 
             embeddings = gpt2.transformer.wte.output.save()
 
     output1 = gpt2.tokenizer.decode(generator.output[0])
 
-    with gpt2.generate(device_map=device, max_new_tokens=3) as generator:
+    with gpt2.generate(max_new_tokens=3) as generator:
 
         with generator.invoke("_ _ _ _ _ _ _ _ _") as invoker:
 
4 changes: 2 additions & 2 deletions examples/adhoc_module.py
@@ -1,9 +1,9 @@
 from engine import Model
 import torch
 
-model = Model("gpt2")
+model = Model("gpt2", device_map='cuda:0')
 
-with model.generate(device_map='cuda:0') as generator:
+with model.generate() as generator:
     with generator.invoke('The Eiffel Tower is in the city of') as invoker:
 
         hidden_states = model.transformer.h[-1].output[0]
4 changes: 2 additions & 2 deletions examples/embeddings.py
@@ -1,10 +1,10 @@
 from engine import Model
 
-model = Model('gpt2')
+model = Model('gpt2', device_map='cuda:0')
 
 print(model)
 
-with model.generate(device_map='cuda:0', max_new_tokens=3) as generator:
+with model.generate(max_new_tokens=3) as generator:
 
     with generator.invoke("Madison square garden is located in the city of New") as invoker:
 
4 changes: 2 additions & 2 deletions examples/modulize.py
@@ -1,12 +1,12 @@
 from engine import Model
 
-model = Model('gpt2')
+model = Model('gpt2', device_map='cuda:0')
 
 print(model.transformer.h[1].attn.graph)
 
 model.modulize(model.transformer.h[1].attn, 'softmax_0', 'attention_probs')
 
-with model.generate(device_map='cuda:0', max_new_tokens=3) as generator:
+with model.generate(max_new_tokens=3) as generator:
 
     with generator.invoke('Hello world') as invoker:
 
3 changes: 1 addition & 2 deletions examples/multitoken.py
@@ -1,6 +1,6 @@
 from engine import Model
 
-model = Model("gpt2")
+model = Model("gpt2", device_map='cuda:0')
 
 
 def get_scores():
@@ -15,7 +15,6 @@ def decode(scores):
 
 
 with model.generate(
-    device_map="cuda:0",
     max_new_tokens=3,
     return_dict_in_generate=True,
     output_scores=True,
4 changes: 2 additions & 2 deletions examples/test.py
@@ -2,13 +2,13 @@
 from engine import Model
 import torch
 # Get model wrapper for any model you can get with AutoConfig.from_pretrained(model_name)
-model = Model('gpt2')
+model = Model('gpt2',device_map='cuda:0')
 
 # Prints normal module tree to show access tree for modules
 print(model)
 
 # Invoke using a prompt
-with model.generate(device_map='cuda:0', max_new_tokens=3) as generator:
+with model.generate(max_new_tokens=3) as generator:
     with generator.invoke('Hello world ') as invoker:
 
         # See the input prompt seperated into token strings
2 changes: 1 addition & 1 deletion examples/test_server.py
@@ -4,7 +4,7 @@
 
 print(model)
 
-with model.generate(device_map="server") as generator:
+with model.generate(server=True) as generator:
     with generator.invoke("Hello world") as invoker:
         hiddenstates = model.transformer.h[2].output.save()
 
2 changes: 1 addition & 1 deletion examples/test_server_llama.py
@@ -4,7 +4,7 @@
 
 print(model)
 
-with model.generate(device_map='server', max_new_tokens=1, return_dict_in_generate=True, output_scores=True) as generator:
+with model.generate(server=True, max_new_tokens=1, return_dict_in_generate=True, output_scores=True) as generator:
     with generator.invoke('Hello world') as invoker:
 
         hiddenstates = model.model.layers[0].output.save()
14 changes: 7 additions & 7 deletions examples/tl/Main_Demo.ipynb
@@ -301,7 +301,7 @@
    },
    "outputs": [],
    "source": [
-    "model = engine.Model('gpt2')"
+    "model = engine.Model('gpt2', device_map='cuda:0')"
    ]
   },
   {
@@ -347,7 +347,7 @@
     "\n",
     "For this demo notebook we'll look at GPT-2 Small, an 80M parameter model. To try the model the model out, let's find the loss on this paragraph!\"\"\"\n",
     "\n",
-    "with model.generate(device_map='cuda:0', max_new_tokens=1) as generator:\n",
+    "with model.generate(max_new_tokens=1) as generator:\n",
     "\n",
     "    with generator.invoke(model_description_text) as invoker:\n",
     "\n",
@@ -397,7 +397,7 @@
    "source": [
     "gpt2_text = \"Natural language processing tasks, such as question answering, machine translation, reading comprehension, and summarization, are typically approached with supervised learning on taskspecific datasets.\"\n",
     "\n",
-    "with model.generate(device_map='cuda:0', max_new_tokens=1, output_attentions=True) as generator:\n",
+    "with model.generate(max_new_tokens=1, output_attentions=True) as generator:\n",
     "\n",
     "    with generator.invoke(gpt2_text, output_attentions=True) as invoker:\n",
     "\n",
@@ -704,7 +704,7 @@
     "correct_index = model.tokenizer(\" John\")['input_ids'][0]\n",
     "incorrect_index = model.tokenizer(\" Mary\")['input_ids'][0]\n",
     "\n",
-    "with model.generate(device_map='cuda:0', max_new_tokens=1) as generator:\n",
+    "with model.generate(max_new_tokens=1) as generator:\n",
     "\n",
     "    with generator.invoke(clean_prompt) as invoker:\n",
     "\n",
@@ -949,7 +949,7 @@
    ],
    "source": [
     "\n",
-    "with model.generate(device_map='cuda:0', max_new_tokens=1, output_attentions=True) as generator:\n",
+    "with model.generate(max_new_tokens=1, output_attentions=True) as generator:\n",
     "\n",
     "    with generator.invoke(repeated_tokens, output_attentions=True) as invoker:\n",
     "\n",
@@ -1076,9 +1076,9 @@
    }
   ],
   "source": [
-    "distilgpt2 = engine.Model('distilgpt2')\n",
+    "distilgpt2 = engine.Model('distilgpt2', device_map='cuda:0')\n",
     "\n",
-    "with distilgpt2.generate(device_map='cuda:0', max_new_tokens=1, output_attentions=True) as generator:\n",
+    "with distilgpt2.generate(max_new_tokens=1, output_attentions=True) as generator:\n",
     "\n",
     "    with generator.invoke(repeated_tokens, output_attentions=True) as invoker:\n",
     "\n",
4 changes: 2 additions & 2 deletions examples/tl/attention.ipynb
@@ -41,11 +41,11 @@
    "from engine.fx.Proxy import Proxy\n",
    "from engine import util\n",
    "\n",
-    "model = Model('gpt2')\n",
+    "model = Model('gpt2',device_map='cuda:0')\n",
    "\n",
    "gpt2_text = \"Natural language processing tasks, such as question answering, machine translation, reading comprehension, and summarization, are typically approached with supervised learning on taskspecific datasets.\"\n",
    "\n",
-    "with model.generate(device_map='cuda:0', max_new_tokens=1, output_attentions=True) as generator:\n",
+    "with model.generate(max_new_tokens=1, output_attentions=True) as generator:\n",
    "\n",
    "    with generator.invoke(gpt2_text, output_attentions=True) as invoker:\n",
    "\n",
4 changes: 2 additions & 2 deletions examples/tl/attention.py
@@ -2,11 +2,11 @@
 from engine.fx.Proxy import Proxy
 from engine import util
 
-model = Model('gpt2')
+model = Model('gpt2', device_map='cuda:0')
 
 gpt2_text = "Natural language processing tasks, such as question answering, machine translation, reading comprehension, and summarization, are typically approached with supervised learning on taskspecific datasets."
 
-with model.generate(device_map='cuda:0', max_new_tokens=1, output_attentions=True) as generator:
+with model.generate(max_new_tokens=1, output_attentions=True) as generator:
 
     with generator.invoke(gpt2_text, output_attentions=True) as invoker:
 
4 changes: 2 additions & 2 deletions examples/tl/patching.py
@@ -3,7 +3,7 @@
 from engine import util
 import torch
 
-model = Model('gpt2')
+model = Model('gpt2', device_map='cuda:0')
 
 clean_prompt = "After John and Mary went to the store, Mary gave a bottle of milk to"
 corrupted_prompt = "After John and Mary went to the store, John gave a bottle of milk to"
@@ -12,7 +12,7 @@
 incorrect_index = model.tokenizer(" Mary")['input_ids'][0]
 
 
-with model.generate(device_map='cuda:0', max_new_tokens=1) as generator:
+with model.generate(max_new_tokens=1) as generator:
 
     with generator.invoke(clean_prompt) as invoker:
 
