-
Notifications
You must be signed in to change notification settings - Fork 667
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Feat oneflow as backend of torch #10205
base: master
Are you sure you want to change the base?
Conversation
print("==>inputs ", inputs) | ||
with flow.mock_torch.enable(lazy=True): | ||
import torch | ||
outs = gm.forward(inputs) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里还是会走到 torch
File "<eval_with_key>.1", line 5, in forward
relu = torch.relu(x); x = None
TypeError: relu(): argument 'input' (position 1) must be Tensor, not Tensor
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
如果绕不过去,可能得借助 fx 来改写而不是 mock
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个地方是不是因为gm并不在mock_torch的作用域类所以失效了
Code got formatted by CI. Please request CI again if you still want to have this PR merged. If the PR is from a forked repo, please download the patch files from the GitHub Actions web page and apply them locally. |
Speed stats:
|
Code got formatted by CI. Please request CI again if you still want to have this PR merged. If the PR is from a forked repo, please download the patch files from the GitHub Actions web page and apply them locally. |
Speed stats:
|
def _test_relu(test_case, device): | ||
from typing import List | ||
import torch | ||
from oneflow.utils.backend.torch_compile import register_oneflowc |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
注册 oneflowc
x = torch.tensor(input_arr, device=device) | ||
eager_out = torch.relu(x) | ||
|
||
@torch.compile(backend="oneflowc") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
使用 oneflowc
Code got formatted by CI. Please request CI again if you still want to have this PR merged. If the PR is from a forked repo, please download the patch files from the GitHub Actions web page and apply them locally. |
|
||
@register_backend | ||
@fake_tensor_unsupported | ||
def oneflowc(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
oneflowc的作为后端
Code got formatted by CI. Please request CI again if you still want to have this PR merged. If the PR is from a forked repo, please download the patch files from the GitHub Actions web page and apply them locally. |
Code got formatted by CI. Please request CI again if you still want to have this PR merged. If the PR is from a forked repo, please download the patch files from the GitHub Actions web page and apply them locally. |
Code got formatted by CI. Please request CI again if you still want to have this PR merged. If the PR is from a forked repo, please download the patch files from the GitHub Actions web page and apply them locally. |
Speed stats:
|
# The target attribute is the function | ||
# that call_function calls. | ||
if node.target == torch.relu: | ||
node.target = oneflow.relu |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这是最早的思路,通过 torch python api -> torch fx -> oneflow python api
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
我觉得这个方法挺好的啊,这里都不用这样写吧。
可以写着这样通用的?
getattr(oneflow, ".".join(node.target.split(".")[1::]))
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
哦不对,要把function对象取得名字,应该有办法吧,node这个对象也会有字符串的field吧
) | ||
x = torch.tensor(input_arr, device=device) | ||
torch.nn.init.constant_(linear.weight, 2.3) | ||
eager_out = linear(x) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
后来 bbuff 尝试用 fx 的思路做 lnear 这种带状态的 module 时,遇到了参数 tensor 共享也需要考虑和 oneflow 共享的麻烦。
Code got formatted by CI. Please request CI again if you still want to have this PR merged. If the PR is from a forked repo, please download the patch files from the GitHub Actions web page and apply them locally. |
Code got formatted by CI. Please request CI again if you still want to have this PR merged. If the PR is from a forked repo, please download the patch files from the GitHub Actions web page and apply them locally. |
Speed stats:
|
Speed stats:
|
Code got formatted by CI. Please request CI again if you still want to have this PR merged. If the PR is from a forked repo, please download the patch files from the GitHub Actions web page and apply them locally. |
Speed stats:
|
oneflow as a backend of torch compile.