0.3.2 (#64)
* Fix finalize_features flooding CPU memory: remove raw-tensor logging and load mapped gradients to the device (#59)

* Load gradients to the device instead of CPU with torch during finalize

Load data directly on the device instead of moving it from CPU to GPU during the score-computation steps.

* Remove logging of raw tensors

* Remove dependency on proj_matrix during scoring (#61)

finalize_features deletes proj_matrix, but the basic projector assumes it exists. Remove the dependency.

* Fix grads type check in iterative gradient computers

Co-authored-by: TheaperDeng <junweid2@illinois.edu>

* [Added feature] Regularization term for inv(XTX) calculation (#63); see the usage sketch after the changed-files summary below

* Updated score_computers.py for lambda_reg

* Updated traker.py to include a lambda_reg term in arguments

* minor fixes

---------


Co-authored-by: Junwei Deng <junweid2@illinois.edu>
Co-authored-by: Jiadong Guo <jaedon.guo@gmail.com>
Co-authored-by: heale04 <136350745+heale04@users.noreply.github.com>
4 people authored Jan 17, 2024
1 parent 39bf22a commit 96a75b6
Showing 6 changed files with 169 additions and 25 deletions.
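
For context, here is a minimal usage sketch of the new lambda_reg argument added in #63. The model, train-set size, and regularization value are illustrative placeholders, not taken from this commit:

from torchvision.models import resnet18
from trak import TRAKer

model = resnet18()
traker = TRAKer(
    model=model,
    task="image_classification",
    train_set_size=50_000,
    device="cpu",
    lambda_reg=1e-3,  # ridge term added to X^T X before it is inverted
)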
26 changes: 25 additions & 1 deletion tests/test_class.py
@@ -374,6 +374,30 @@ def test_custom_model_output(tmp_path, cpu_proj):
)


def test_iterative_gradient_computer(tmp_path, cpu_proj):
from trak.gradient_computers import IterativeGradientComputer
from trak.projectors import NoOpProjector

model = resnet18()
N = 5
batch = ch.randn(N, 3, 32, 32), ch.randint(low=0, high=10, size=(N,))
traker = TRAKer(
model=model,
task="iterative_image_classification",
save_dir=tmp_path,
train_set_size=N,
logging_level=logging.DEBUG,
device="cpu",
use_half_precision=False,
projector=NoOpProjector(),
proj_dim=0,
gradient_computer=IterativeGradientComputer,
)
ckpt = model.state_dict()
traker.load_checkpoint(ckpt, model_id=0)
traker.featurize(batch, num_samples=N)


def test_grad_wrt_last_layer(tmp_path):
model = resnet18().eval()
N = 5
@@ -402,7 +426,7 @@ def test_grad_wrt_last_layer(tmp_path):
def test_grad_wrt_last_layer_cuda(tmp_path):
model = resnet18().cuda().eval()
N = 5
batch = ch.randn(N, 3, 32, 32).cuda(), ch.randint(low=0, high=10, size=(N,)).cuda()
batch = ch.randn(N, 3, 4, 4).cuda(), ch.randint(low=0, high=10, size=(N,)).cuda()
traker = TRAKer(
model=model,
task="image_classification",
17 changes: 12 additions & 5 deletions trak/gradient_computers.py
@@ -136,9 +136,9 @@ def compute_per_sample_grad(self, batch: Iterable[Tensor]) -> Tensor:
batch of data
Returns:
Tensor:
gradients of the model output function of each sample in the
batch with respect to the model's parameters.
dict[Tensor]:
A dictionary where each key is a parameter name and the value is
the gradient tensor for that parameter.
"""
# taking the gradient wrt weights (second argument of get_output, hence argnums=1)
@@ -183,6 +183,9 @@ def compute_loss_grad(self, batch: Iterable[Tensor]) -> Tensor:
batch (Iterable[Tensor]):
batch of data
Returns:
Tensor:
The gradient of the loss with respect to the model output.
"""
return self.modelout_fn.get_out_to_loss_grad(
self.model, self.func_weights, self.func_buffers, batch
@@ -229,7 +232,7 @@ def compute_per_sample_grad(self, batch: Iterable[Tensor]) -> Tensor:
batch_size = batch[0].shape[0]
grads = ch.zeros(batch_size, self.grad_dim).to(batch[0].device)

margin = self.modelout_fn.get_output(self.model, *batch)
margin = self.modelout_fn.get_output(self.model, None, None, *batch)
for ind in range(batch_size):
grads[ind] = parameters_to_vector(
ch.autograd.grad(margin[ind], self.model_params, retain_graph=True)
@@ -254,5 +257,9 @@ def compute_loss_grad(self, batch: Iterable[Tensor]) -> Tensor:
Args:
batch (Iterable[Tensor]):
batch of data
Returns:
Tensor:
The gradient of the loss with respect to the model output.
"""
return self.modelout_fn.get_out_to_loss_grad(self.model, batch)
return self.modelout_fn.get_out_to_loss_grad(self.model, None, None, batch)
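
A minimal standalone sketch of the iterative per-sample gradient pattern the class above uses (one autograd.grad call per example, flattened with parameters_to_vector); the helper below is an illustration, not the library class:

import torch
from torch.nn.utils import parameters_to_vector

def per_sample_grads(model: torch.nn.Module, outputs: torch.Tensor) -> torch.Tensor:
    # outputs: per-sample model outputs (e.g. margins), shape (batch,),
    # computed from model(...) with gradients enabled
    params = list(model.parameters())
    grad_dim = sum(p.numel() for p in params)
    grads = torch.zeros(outputs.shape[0], grad_dim, device=outputs.device)
    for i in range(outputs.shape[0]):
        # retain_graph=True keeps the shared forward graph alive for the next sample
        grads[i] = parameters_to_vector(
            torch.autograd.grad(outputs[i], params, retain_graph=True)
        )
    return grads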
98 changes: 98 additions & 0 deletions trak/modelout_functions.py
@@ -5,6 +5,7 @@
- :class:`.ImageClassificationModelOutput`
- :class:`.CLIPModelOutput`
- :class:`.TextClassificationModelOutput`
- :class:`.IterativeImageClassificationModelOutput`
These classes implement methods that transform input batches to the desired
model output (e.g. logits, loss, etc). See Sections 2 & 3 of `our paper
@@ -444,8 +445,105 @@ def get_out_to_loss_grad(
return (1 - ps).clone().detach().unsqueeze(-1)


class IterativeImageClassificationModelOutput(AbstractModelOutput):
"""Margin for (multiclass) image classification. See Section 3.3 of `our
paper <https://arxiv.org/abs/2303.14186>`_ for more details.
"""

def __init__(self, temperature: float = 1.0) -> None:
"""
Args:
temperature (float, optional): Temperature to use inside the
softmax for the out-to-loss function. Defaults to 1.
"""
super().__init__()
self.softmax = ch.nn.Softmax(-1)
self.loss_temperature = temperature

@staticmethod
def get_output(
model: Module,
weights: Iterable[Tensor],
buffers: Iterable[Tensor],
images: Tensor,
labels: Tensor,
) -> Tensor:
"""For a given input :math:`z=(x, y)` and model parameters :math:`\\theta`,
let :math:`p(z, \\theta)` be the softmax probability of the correct class.
This method implements the model output function
.. math::
\\log(\\frac{p(z, \\theta)}{1 - p(z, \\theta)}).
It uses functional models from torch.func (previously functorch) to make
the per-sample gradient computations (much) faster. For more details on
what functional models are, and how to use them, please refer to
https://pytorch.org/docs/stable/func.html and
https://pytorch.org/functorch/stable/notebooks/per_sample_grads.html.
Args:
model (torch.nn.Module):
torch model
weights (Iterable[Tensor]):
functorch model weights (added so we don't break abstraction)
buffers (Iterable[Tensor]):
functorch model buffers (added so we don't break abstraction)
images (Tensor):
input images
labels (Tensor):
input labels
Returns:
Tensor:
model output for the given image-label pair :math:`z`
"""
logits = model(images)
bindex = ch.arange(logits.shape[0]).to(logits.device, non_blocking=False)
logits_correct = logits[bindex, labels]

cloned_logits = logits.clone()
# remove the logits of the correct labels from the sum
# in logsumexp by setting to -ch.inf
cloned_logits[bindex, labels] = ch.tensor(
-ch.inf, device=logits.device, dtype=logits.dtype
)

margins = logits_correct - cloned_logits.logsumexp(dim=-1)
return margins

def get_out_to_loss_grad(
self, model, weights, buffers, batch: Iterable[Tensor]
) -> Tensor:
"""Computes the (reweighting term Q in the paper)
Args:
model (torch.nn.Module):
torch model
weights (Iterable[Tensor]):
functorch model weights
buffers (Iterable[Tensor]):
functorch model buffers
batch (Iterable[Tensor]):
input batch
Returns:
Tensor:
out-to-loss (reweighting term) for the input batch
"""
images, labels = batch
logits = model(images)
# here we are directly implementing the gradient instead of relying on autodiff to do
# that for us
ps = self.softmax(logits / self.loss_temperature)[
ch.arange(logits.size(0)), labels
]
return (1 - ps).clone().detach().unsqueeze(-1)


TASK_TO_MODELOUT = {
"image_classification": ImageClassificationModelOutput,
"clip": CLIPModelOutput,
"text_classification": TextClassificationModelOutput,
"iterative_image_classification": IterativeImageClassificationModelOutput,
}
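
The margin that get_output computes can be reproduced with a few lines of plain PyTorch; a minimal sketch (standalone, outside the class) follows:

import torch

def classification_margins(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    idx = torch.arange(logits.shape[0], device=logits.device)
    logits_correct = logits[idx, labels]
    masked = logits.clone()
    # drop the correct-class logit from the logsumexp over the remaining classes
    masked[idx, labels] = float("-inf")
    return logits_correct - masked.logsumexp(dim=-1)

# e.g.: classification_margins(torch.randn(4, 10), torch.randint(0, 10, (4,)))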
19 changes: 11 additions & 8 deletions trak/projectors.py
@@ -115,7 +115,9 @@ def project(self, grads: Tensor, model_id: int) -> Tensor:
Returns:
Tensor: the (non-)projected gradients
"""
return vectorize(grads, device=self.device)
if isinstance(grads, dict):
grads = vectorize(grads, device=self.device)
return grads

def free_memory(self):
"""A no-op method."""
@@ -190,7 +192,9 @@ def generate_sketch_matrix(self):
raise KeyError(f"Projection type {self.proj_type} not recognized.")

def project(self, grads: Tensor, model_id: int) -> Tensor:
grads = vectorize(grads, device=self.device)
if isinstance(grads, dict):
grads = vectorize(grads, device=self.device)

grads = grads.to(dtype=self.dtype)
if model_id != self.model_id:
self.model_id = model_id
@@ -254,7 +258,7 @@ def free_memory(self):
def get_generator_states(self):
self.generator_states = []
self.seeds = []
self.jl_size = self.proj_matrix.numel()
self.jl_size = self.grad_dim * self.block_size

for i in range(self.num_blocks):
s = self.seed + int(1e3) * i + int(1e5) * self.model_id
@@ -283,7 +287,8 @@ def generate_sketch_matrix(self, generator_state):
raise KeyError(f"Projection type {self.proj_type} not recognized.")

def project(self, grads: Tensor, model_id: int) -> Tensor:
grads = vectorize(grads, device=self.device)
if isinstance(grads, dict):
grads = vectorize(grads, device=self.device)
grads = grads.to(dtype=self.dtype)
sketch = ch.zeros(
size=(grads.size(0), self.proj_dim), dtype=self.dtype, device=self.device
@@ -380,10 +385,10 @@ def project(
self,
grads: Union[dict, Tensor],
model_id: int,
is_grads_dict: bool = True,
) -> Tensor:
if is_grads_dict:
if isinstance(grads, dict):
grads = vectorize(grads, device=self.device)

batch_size = grads.shape[0]

effective_batch_size = 32
@@ -486,7 +491,6 @@ def project(self, grads, model_id):
self.projector_per_chunk[projector_index].project(
self.ch_input[:, :pointer].contiguous(),
model_id=model_id,
is_grads_dict=False,
)
)
# reset counter
Expand All @@ -506,7 +510,6 @@ def project(self, grads, model_id):
self.projector_per_chunk[projector_index].project(
self.ch_input[:actual_bs, :pointer].contiguous(),
model_id=model_id,
is_grads_dict=False,
)
)

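
The recurring isinstance(grads, dict) guard keeps project() working whether it receives the per-parameter gradient dictionary produced by the functional gradient computer or an already-vectorized matrix. A minimal sketch of the idea, with vectorize_grads standing in for the library's vectorize helper:

import torch

def vectorize_grads(grads: dict, device) -> torch.Tensor:
    # flatten each (batch, ...) per-parameter gradient and concatenate along dim 1
    batch_size = next(iter(grads.values())).shape[0]
    return torch.cat([g.reshape(batch_size, -1) for g in grads.values()], dim=1).to(device)

def project_input(grads, device) -> torch.Tensor:
    if isinstance(grads, dict):
        grads = vectorize_grads(grads, device)
    return grads  # shape (batch, grad_dim), ready for projection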
10 changes: 9 additions & 1 deletion trak/score_computers.py
@@ -122,6 +122,7 @@ def __init__(
device: torch.device,
CUDA_MAX_DIM_SIZE: int = 20_000,
logging_level=logging.INFO,
lambda_reg: float = 0.0,
) -> None:
"""
Args:
@@ -132,11 +133,14 @@
Size of block for block-wise matmuls. Defaults to 100_000.
logging_level (logging level, optional):
Logging level for the logger. Defaults to logging.info.
lambda_reg (float):
Regularization term added as an :math:`\ell_2` penalty on :math:`X^\top X` before inversion. Defaults to 0.0.
"""
super().__init__(dtype, device)
self.CUDA_MAX_DIM_SIZE = CUDA_MAX_DIM_SIZE
self.logger = logging.getLogger("ScoreComputer")
self.logger.setLevel(logging_level)
self.lambda_reg = lambda_reg

def get_xtx(self, grads: Tensor) -> Tensor:
self.proj_dim = grads.shape[1]
@@ -152,7 +156,11 @@ def get_xtx(self, grads: Tensor) -> Tensor:

def get_x_xtx_inv(self, grads: Tensor, xtx: Tensor) -> Tensor:
blocks = ch.split(grads, split_size_or_sections=self.CUDA_MAX_DIM_SIZE, dim=0)
xtx_inv = ch.linalg.inv(xtx.to(ch.float32))

xtx_reg = xtx + self.lambda_reg * torch.eye(
xtx.size(dim=0), device=xtx.device, dtype=xtx.dtype
)
xtx_inv = ch.linalg.inv(xtx_reg.to(ch.float32))

# center X^TX inverse a bit to avoid numerical issues when going to float16
xtx_inv /= xtx_inv.abs().mean()
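
A minimal sketch of what the new lambda_reg term does to the X^T X inversion, written with plain torch calls rather than the class internals:

import torch

def regularized_xtx_inv(grads: torch.Tensor, lambda_reg: float = 0.0) -> torch.Tensor:
    # grads: (num_samples, proj_dim) projected gradients
    xtx = grads.T @ grads
    # the ridge term lambda * I keeps X^T X well-conditioned before inversion
    xtx_reg = xtx + lambda_reg * torch.eye(xtx.size(0), device=xtx.device, dtype=xtx.dtype)
    return torch.linalg.inv(xtx_reg.to(torch.float32))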
24 changes: 14 additions & 10 deletions trak/traker.py
@@ -64,6 +64,7 @@ def __init__(
proj_max_batch_size: int = 32,
projector_seed: int = 0,
grad_wrt: Optional[Iterable[str]] = None,
lambda_reg: float = 0.0,
) -> None:
"""
@@ -72,11 +73,9 @@
model to use for TRAK
task (Union[AbstractModelOutput, str]):
Type of model that TRAK will be run on. Accepts either one of
the following strings:
- :code:`image_classification`
- :code:`text_classification`
- :code:`clip`
or an instance of some implementation of the abstract class
the following strings: 1) :code:`image_classification` 2)
:code:`text_classification` 3) :code:`clip` or an instance of
some implementation of the abstract class
:class:`.AbstractModelOutput`.
train_set_size (int):
Size of the train set that TRAK is featurizing
@@ -129,7 +128,10 @@ def __init__(
as they appear in the model's state dictionary. If None,
gradients are taken with respect to all model parameters.
Defaults to None.
lambda_reg (float):
The :math:`\ell_2` (ridge) regularization penalty added to the
:math:`X^\top X` term in score computers when computing the matrix
inverse :math:`(X^\top X)^{-1}`. Defaults to 0.
"""

self.model = model
Expand All @@ -138,6 +140,7 @@ def __init__(
self.device = device
self.dtype = ch.float16 if use_half_precision else ch.float32
self.grad_wrt = grad_wrt
self.lambda_reg = lambda_reg

logging.basicConfig()
self.logger = logging.getLogger("TRAK")
@@ -183,7 +186,10 @@ def __init__(
if score_computer is None:
score_computer = BasicScoreComputer
self.score_computer = score_computer(
dtype=self.dtype, device=self.device, logging_level=logging_level
dtype=self.dtype,
device=self.device,
logging_level=logging_level,
lambda_reg=self.lambda_reg,
)

metadata = {
@@ -473,12 +479,10 @@ def finalize_features(

self.saver.load_current_store(model_id)

g = ch.as_tensor(self.saver.current_store["grads"])
g = ch.as_tensor(self.saver.current_store["grads"], device=self.device)
xtx = self.score_computer.get_xtx(g)
self.logger.debug(f"XTX is {xtx}")

features = self.score_computer.get_x_xtx_inv(g, xtx)
self.logger.debug(f"Features are {features}")
self.saver.current_store["features"][:] = features.to(self.dtype).cpu()
if del_grads:
self.saver.del_grads(model_id)
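
The finalize_features fix boils down to constructing the gradients tensor on the compute device in one step rather than keeping it on CPU and shuttling blocks to the GPU later. A minimal sketch, with an illustrative numpy array standing in for the saver's memory-mapped store:

import numpy as np
import torch

grads_np = np.random.randn(1000, 2048).astype(np.float32)  # stand-in for the mmapped grads store
device = "cuda" if torch.cuda.is_available() else "cpu"
g = torch.as_tensor(grads_np, device=device)  # built directly on the target device
xtx = g.T @ g  # subsequent score computation stays on-device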
