Skip to content

Commit

Permalink
[Paddle] Replace paddle.fluid imports with paddle.base (#633)
Browse files Browse the repository at this point in the history
* Replace paddle.fluid imports with paddle.base

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Remove paddle.fluid usage from tests

Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------

Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
  • Loading branch information
timmoon10 and ksivaman authored Jan 30, 2024
1 parent ef4b1d1 commit 8d3b62d
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 6 deletions.
2 changes: 1 addition & 1 deletion qa/L0_paddle_lint/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,5 @@ then
cp $TE_PATH/qa/L0_paddle_lint/pylintrc $TE_PATH
cd $TE_PATH
echo "Checking Python files"
pylint --recursive=y transformer_engine/common transformer_engine/paddle
python -m pylint --recursive=y transformer_engine/common transformer_engine/paddle
fi
7 changes: 5 additions & 2 deletions tests/paddle/dist_launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@
import time
import unittest

from paddle import fluid
try:
from paddle.base import core
except ImportError:
from paddle.fluid import core
from paddle.distributed.utils.launch_utils import (
TrainerProc,
find_free_ports,
Expand Down Expand Up @@ -114,7 +117,7 @@ def run_2gpu(
allocator_strategy="auto_growth",
):
"""Run target file in subprocesses"""
if (not fluid.core.is_compiled_with_cuda() or fluid.core.get_cuda_device_count() == 0):
if (not core.is_compiled_with_cuda() or core.get_cuda_device_count() == 0):
return

selected_gpus = get_gpus('0,1')
Expand Down
8 changes: 6 additions & 2 deletions transformer_engine/paddle/layer/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,12 @@
import numpy as np

import paddle
from paddle.fluid import core
from paddle.fluid.framework import _dygraph_tracer
try:
from paddle.base import core
from paddle.base.framework import _dygraph_tracer
except ImportError:
from paddle.fluid import core
from paddle.fluid.framework import _dygraph_tracer

from ..constants import FP8BwdTensors, dist_group_type
from ..cpp_extensions import cast_transpose, cast_transpose_bgrad, cast_to_fp8, transpose
Expand Down
5 changes: 4 additions & 1 deletion transformer_engine/paddle/profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@

from contextlib import contextmanager

from paddle.fluid import core
try:
from paddle.base import core
except ImportError:
from paddle.fluid import core


@contextmanager
Expand Down

0 comments on commit 8d3b62d

Please sign in to comment.