Skip to content

Commit

Permalink
Update pt params converter (#2989)
Browse files Browse the repository at this point in the history
* update pt params converter

* use exclude_vars

* print warning

* add return value
  • Loading branch information
holgerroth authored Oct 4, 2024
1 parent 5ba432d commit 85a7e35
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 4 deletions.
2 changes: 2 additions & 0 deletions nvflare/app_common/abstract/params_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from abc import ABC, abstractmethod
from typing import Any, List

Expand All @@ -23,6 +24,7 @@
class ParamsConverter(ABC):
def __init__(self, supported_tasks: List[str] = None):
self.supported_tasks = supported_tasks
self.logger = logging.getLogger(self.__class__.__name__)

def process(self, task_name: str, shareable: Shareable, fl_ctx: FLContext) -> Shareable:
if not self.supported_tasks or task_name in self.supported_tasks:
Expand Down
34 changes: 30 additions & 4 deletions nvflare/app_opt/pt/params_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,42 @@
class NumpyToPTParamsConverter(ParamsConverter):
def convert(self, params: Dict, fl_ctx) -> Dict:
tensor_shapes = fl_ctx.get_prop("tensor_shapes")
exclude_vars = fl_ctx.get_prop("exclude_vars")

return_params = {}
if tensor_shapes:
return {
return_params = {
k: torch.as_tensor(np.reshape(v, tensor_shapes[k])) if k in tensor_shapes else torch.as_tensor(v)
for k, v in params.items()
}
else:
return {k: torch.as_tensor(v) for k, v in params.items()}
return_params = {k: torch.as_tensor(v) for k, v in params.items()}

if exclude_vars:
for k, v in exclude_vars.items():
return_params[k] = v

return return_params


class PTToNumpyParamsConverter(ParamsConverter):
def convert(self, params: Dict, fl_ctx) -> Dict:
fl_ctx.set_prop("tensor_shapes", {k: v.shape for k, v in params.items()})
return {k: v.cpu().numpy() for k, v in params.items()}
return_tensors = {}
tensor_shapes = {}
exclude_vars = {}
for k, v in params.items():
if isinstance(v, torch.Tensor):
return_tensors[k] = v.cpu().numpy()
tensor_shapes[k] = v.shape
else:
exclude_vars[k] = v

if tensor_shapes:
fl_ctx.set_prop("tensor_shapes", tensor_shapes)
if exclude_vars:
fl_ctx.set_prop("exclude_vars", exclude_vars)
self.logger.warning(
f"{len(exclude_vars)} vars excluded as they were non-tensor type: " f"{list(exclude_vars.keys())}"
)

return return_tensors

0 comments on commit 85a7e35

Please sign in to comment.