Skip to content

Commit

Permalink
white and black lists for pretrained model (#1528)
Browse files Browse the repository at this point in the history
* white and black lists for pretrained model

* address comments
  • Loading branch information
hnyu authored Sep 5, 2023
1 parent c6cf608 commit 36947f5
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 3 deletions.
39 changes: 37 additions & 2 deletions alf/pretrained_models/pretrained_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import re
import torch
import torch.nn as nn

Expand All @@ -32,6 +33,8 @@ class PretrainedModel(nn.Module):
def __init__(self,
model: nn.Module,
adapter_cls: List[Callable] = [],
module_blacklist: List[str] = None,
module_whitelist: List[str] = None,
name: str = 'PretrainedModel'):
"""
Args:
Expand All @@ -40,27 +43,59 @@ def __init__(self,
model layers. An adapter instance of each class will be created
for each of all qualified layers. The adapter weights will be
stored in the adapter instance.
module_blacklist: an optional blacklist of modules not to be adapted.
Each entry can be a regex or a substring of the module name.
module_whitelist: an optional whitelist of modules to be adapted.
Each entry can be a regex or a substring of the module name. By
default this is None, which means all modules are valid. At most
one of ``module_blacklist`` and ``module_whitelist`` can be
provided.
name: name of the pretrained model
"""
super().__init__()
assert not (module_blacklist and module_whitelist), (
"Blacklist and whitelist cannot be provided at the same time!")
self._name = name
# Freeze all parameters
for para in model.parameters():
para.requires_grad = False
# Use a list trick to let pytorch ignore this module for checkpointing
self._model = [model]
# This is the ONLY module that affects checkpointing
self._adapters = nn.ModuleList()
for m in self.model.modules():
self._adapted_module_names = []
for name, m in self.model.named_modules():
skip = False
if module_blacklist is not None:
for b in module_blacklist:
if re.search(b, name):
skip = True
break
elif module_whitelist is not None:
skip = True
for w in module_whitelist:
if re.search(w, name):
skip = False
break
if skip:
continue

for acls in adapter_cls:
if acls.can_adapt(m):
self._adapters.append(acls(m))
self._name = name
self._adapted_module_names.append(name)

@property
def model(self) -> nn.Module:
    """Return the base model.

    The base model is kept as the first element of the ``_model`` list
    rather than as a direct attribute, so it is read back from there.
    """
    holder = self._model
    return holder[0]

@property
def adapted_module_names(self) -> List[str]:
    """Names of the modules that were given an adapter.

    Returns:
        the stored list of module names, in the order the adapters were
        added.
    """
    names = self._adapted_module_names
    return names

def remove_adapter(self) -> nn.ModuleList:
"""Remove the adapter (if existed).
Expand Down
58 changes: 57 additions & 1 deletion alf/pretrained_models/pretrained_model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,30 @@
import os
import json

import torch
import torch.nn as nn

import alf
from alf.pretrained_models.pretrained_model import PretrainedModel
import alf.utils.checkpoint_utils as ckpt_utils
from alf.utils.checkpoint_utils_test import Net
from alf.pretrained_models.model_adapters.lora import (LinearAdapter,
Conv2dAdapter)


class Net(nn.Module):
    """Small conv + linear network used as the base model in these tests.

    Two 3x3 convolutions (3 -> 6 -> 1 channels, padding 1) are followed by
    two linear layers (``size**2 -> size**2 -> 10``).

    Args:
        size: the value whose square is the flattened feature count fed to
            the linear layers.
    """

    def __init__(self, size=5):
        super().__init__()
        num_features = size**2
        # Layers are created in the same order as before so that seeded
        # initialization yields the same weights.
        self.conv1 = nn.Conv2d(3, 6, 3, padding=1)
        self.conv2 = nn.Conv2d(6, 1, 3, padding=1)
        self.fc1 = nn.Linear(num_features, num_features)
        self.fc2 = nn.Linear(num_features, 10)
        self._size = size

    def forward(self, input):
        """Run both convolutions, flatten to ``size**2`` columns, then
        apply the two linear layers."""
        feature_map = self.conv2(self.conv1(input))
        flattened = feature_map.reshape(-1, self._size**2)
        hidden = self.fc1(flattened)
        return self.fc2(hidden)


class PretrainedModelTest(alf.test.TestCase):
def test_pretrained_model_ckpt(self):
with tempfile.TemporaryDirectory() as ckpt_dir:
Expand Down Expand Up @@ -61,6 +77,46 @@ def test_pretrained_model_ckpt(self):

ckpt_mngr.load(0)

def test_finetuning_grad(self):
    """Check that a backward pass in half precision populates gradients.

    A half-precision ``Net`` is wrapped with linear and conv adapters, a
    forward/backward pass is run, and every parameter reported by the
    wrapper's ``named_parameters()`` must end up with a gradient.
    """
    alf.reset_configs()
    alf.config('Conv2dAdapter', rank=32)
    net = Net(10)
    net.half()
    pretrained_net = PretrainedModel(
        net, adapter_cls=[LinearAdapter, Conv2dAdapter])
    x = torch.zeros([1, 3, 10, 10], dtype=torch.float16)
    y = pretrained_net(x).sum()
    # Cast the scalar to float32 before calling backward.
    y.float().backward()
    for name, para in pretrained_net.named_parameters():
        # Report which parameter is missing its gradient on failure.
        self.assertIsNotNone(
            para.grad, msg=f"parameter '{name}' has no gradient")

def test_module_blacklist(self):
    """Check that the module blacklist and whitelist restrict adaptation.

    First wraps ``Net`` with a blacklist and verifies that no adapted
    module name matches a blacklist pattern; then, after removing the
    adapters, wraps it again with a whitelist and verifies that every
    adapted module name matches a whitelist pattern.
    """
    # Local import: entries are regexes, so they must be checked with
    # ``re.search`` (as the model does), not with substring containment.
    import re

    alf.reset_configs()
    alf.config('Conv2dAdapter', rank=32)
    alf.config('LinearAdapter', rank=32)
    net = Net(10)
    # This regex will exclude 'conv1' and 'fc1'
    blacklist = ['.*1']
    pretrained_net = PretrainedModel(
        net,
        adapter_cls=[LinearAdapter, Conv2dAdapter],
        module_blacklist=blacklist)
    for name in pretrained_net.adapted_module_names:
        for b in blacklist:
            self.assertIsNone(
                re.search(b, name),
                msg=f"blacklisted module '{name}' was adapted")
    # because of blacklist, adapters only have two weights
    self.assertEqual(len(list(pretrained_net.parameters())), 2)

    pretrained_net.remove_adapter()

    # This regex will only keep 'conv1' and 'conv2'
    whitelist = ['.*conv.*']
    pretrained_net = PretrainedModel(
        net,
        adapter_cls=[LinearAdapter, Conv2dAdapter],
        module_whitelist=whitelist)
    for name in pretrained_net.adapted_module_names:
        self.assertTrue(
            any(re.search(w, name) for w in whitelist),
            msg=f"non-whitelisted module '{name}' was adapted")
    # because of whitelist, adapters only have two weights
    self.assertEqual(len(list(pretrained_net.parameters())), 2)


# Run the tests in this file when it is executed directly as a script.
if __name__ == '__main__':
    alf.test.main()

0 comments on commit 36947f5

Please sign in to comment.