Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat oneflow as backend of torch #10205

Open
wants to merge 33 commits into
base: master
Choose a base branch
from
Open

Feat oneflow as backend of torch #10205

wants to merge 33 commits into from

Conversation

strint
Copy link
Contributor

@strint strint commented Apr 26, 2023

oneflow as a backend of torch compile.

print("==>inputs ", inputs)
with flow.mock_torch.enable(lazy=True):
import torch
outs = gm.forward(inputs)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里还是会走到 torch

File "<eval_with_key>.1", line 5, in forward
    relu = torch.relu(x);  x = None
TypeError: relu(): argument 'input' (position 1) must be Tensor, not Tensor

Copy link
Contributor Author

@strint strint Apr 27, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

如果绕不过去,可能得借助 fx 来改写而不是 mock

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个地方是不是因为gm并不在mock_torch的作用域类所以失效了

@strint strint changed the title Feat graph interplay [Draft]Feat graph interplay Apr 27, 2023
@strint strint marked this pull request as ready for review April 27, 2023 06:48
@github-actions
Copy link
Contributor

Code got formatted by CI. Please request CI again if you still want to have this PR merged. If the PR is from a forked repo, please download the patch files from the GitHub Actions web page and apply them locally.

@github-actions
Copy link
Contributor

Speed stats:

@github-actions
Copy link
Contributor

Code got formatted by CI. Please request CI again if you still want to have this PR merged. If the PR is from a forked repo, please download the patch files from the GitHub Actions web page and apply them locally.

@github-actions
Copy link
Contributor

Speed stats:

def _test_relu(test_case, device):
from typing import List
import torch
from oneflow.utils.backend.torch_compile import register_oneflowc
Copy link
Contributor Author

@strint strint Apr 27, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

注册 oneflowc

x = torch.tensor(input_arr, device=device)
eager_out = torch.relu(x)

@torch.compile(backend="oneflowc")
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

使用 oneflowc

@github-actions
Copy link
Contributor

Code got formatted by CI. Please request CI again if you still want to have this PR merged. If the PR is from a forked repo, please download the patch files from the GitHub Actions web page and apply them locally.


@register_backend
@fake_tensor_unsupported
def oneflowc(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oneflowc的作为后端

@github-actions
Copy link
Contributor

Code got formatted by CI. Please request CI again if you still want to have this PR merged. If the PR is from a forked repo, please download the patch files from the GitHub Actions web page and apply them locally.

@github-actions
Copy link
Contributor

Code got formatted by CI. Please request CI again if you still want to have this PR merged. If the PR is from a forked repo, please download the patch files from the GitHub Actions web page and apply them locally.

@github-actions
Copy link
Contributor

Code got formatted by CI. Please request CI again if you still want to have this PR merged. If the PR is from a forked repo, please download the patch files from the GitHub Actions web page and apply them locally.

@github-actions
Copy link
Contributor

Speed stats:

# The target attribute is the function
# that call_function calls.
if node.target == torch.relu:
node.target = oneflow.relu
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这是最早的思路,通过 torch python api -> torch fx -> oneflow python api

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

我觉得这个方法挺好的啊,这里都不用这样写吧。
可以写着这样通用的?

getattr(oneflow, ".".join(node.target.split(".")[1::]))

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

哦不对,要把function对象取得名字,应该有办法吧,node这个对象也会有字符串的field吧

)
x = torch.tensor(input_arr, device=device)
torch.nn.init.constant_(linear.weight, 2.3)
eager_out = linear(x)
Copy link
Contributor Author

@strint strint Aug 7, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

后来 bbuff 尝试用 fx 的思路做 lnear 这种带状态的 module 时,遇到了参数 tensor 共享也需要考虑和 oneflow 共享的麻烦。

@github-actions
Copy link
Contributor

Code got formatted by CI. Please request CI again if you still want to have this PR merged. If the PR is from a forked repo, please download the patch files from the GitHub Actions web page and apply them locally.

@github-actions
Copy link
Contributor

Code got formatted by CI. Please request CI again if you still want to have this PR merged. If the PR is from a forked repo, please download the patch files from the GitHub Actions web page and apply them locally.

@github-actions
Copy link
Contributor

Speed stats:

@github-actions
Copy link
Contributor

Speed stats:

@github-actions
Copy link
Contributor

Code got formatted by CI. Please request CI again if you still want to have this PR merged. If the PR is from a forked repo, please download the patch files from the GitHub Actions web page and apply them locally.

@github-actions
Copy link
Contributor

Speed stats:

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants