
sakuranew/PointerNet-pytorch


PointerNet-pytorch

Pointer Network, PyTorch version. The decoder uses the boundary model, i.e. it predicts only the start and end positions of a span.

It can be added to BERT as an output layer. The current version supports GRU only.
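To illustrate what "boundary model" means here, the following is a minimal sketch, not the repository's `PtrNet`: a (bi)GRU encoder followed by a `GRUCell` decoder that runs for exactly two steps, with additive attention producing a distribution over input positions at each step (step 0 for the start, step 1 for the end). The class name `BoundaryPointerSketch`, the mean-pooled initial decoder state, and feeding the attention-weighted encoder state back into the decoder are all assumptions made for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class BoundaryPointerSketch(nn.Module):
    """Hypothetical boundary-model pointer decoder (NOT the repo's PtrNet).

    Encodes the input with a (bi)GRU, then decodes exactly two pointers:
    step 0 -> start position, step 1 -> end position.
    """

    def __init__(self, input_size, hidden_size, bi=True):
        super().__init__()
        self.encoder = nn.GRU(input_size, hidden_size,
                              batch_first=True, bidirectional=bi)
        enc_size = hidden_size * (2 if bi else 1)
        self.decoder = nn.GRUCell(enc_size, enc_size)
        # Additive attention: score_j = v^T tanh(W_enc e_j + W_dec d)
        self.w_enc = nn.Linear(enc_size, hidden_size, bias=False)
        self.w_dec = nn.Linear(enc_size, hidden_size, bias=False)
        self.v = nn.Linear(hidden_size, 1, bias=False)

    def forward(self, x):
        enc, _ = self.encoder(x)               # (batch, seq_len, enc_size)
        state = enc.mean(dim=1)                # initial decoder state (assumption)
        step_input = torch.zeros_like(state)   # first decoder input
        keys = self.w_enc(enc)                 # (batch, seq_len, hidden_size)
        outputs = []
        for _ in range(2):                     # boundary model: start, then end
            state = self.decoder(step_input, state)
            query = self.w_dec(state).unsqueeze(1)       # (batch, 1, hidden_size)
            scores = self.v(torch.tanh(keys + query)).squeeze(-1)
            probs = F.softmax(scores, dim=-1)            # (batch, seq_len)
            outputs.append(probs)
            # Attention-weighted encoder state becomes the next decoder input
            step_input = torch.bmm(probs.unsqueeze(1), enc).squeeze(1)
        return torch.stack(outputs, dim=1)     # (batch, 2, seq_len)


torch.manual_seed(0)
sketch = BoundaryPointerSketch(256, 128, bi=True)
out = sketch(torch.randn(2, 5, 256))
print(out.shape)  # torch.Size([2, 2, 5])
```

Called on a `(2, 5, 256)` input it returns a `(2, 2, 5)` tensor whose rows each sum to 1, the same shape as the output shown in the Test section. Stopping after two decoder steps is what distinguishes the boundary model from a sequence-style pointer decoder that emits every position of the answer.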

Test

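from PointerNet import PtrNet
import torch

x = torch.randn(2, 5, 256)      # (batch, seq_len, input_size)
m = PtrNet(256, 128, bi=True)   # input size 256; 128 is presumably the GRU hidden size; bidirectional
print(m(x))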

Output:

tensor([[[0.2175, 0.2090, 0.2347, 0.1744, 0.1645],
         [0.2174, 0.2089, 0.2356, 0.1731, 0.1650]],
        [[0.2420, 0.2518, 0.1624, 0.1893, 0.1545],
         [0.2407, 0.2503, 0.1661, 0.1893, 0.1537]]], grad_fn=<StackBackward>)
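Each row of this output sums to 1 up to rounding (for example 0.2175 + 0.2090 + 0.2347 + 0.1744 + 0.1645 = 1.0001), so it reads as a probability distribution over the 5 input positions. Assuming the layout is `(batch, 2, seq_len)` with row 0 the start pointer and row 1 the end pointer, which is what a boundary model would produce but is not stated in this README, a span can be extracted by maximizing `p_start[i] * p_end[j]` subject to `i <= j`. `best_span` is a hypothetical helper, not part of the repository.

```python
def best_span(start_probs, end_probs, max_len=None):
    """Return (start, end, score) maximizing start_probs[i] * end_probs[j], i <= j.

    Hypothetical helper (not part of PointerNet-pytorch). max_len optionally
    caps the span length j - i + 1.
    """
    best = (0, 0, -1.0)
    for i, p_start in enumerate(start_probs):
        for j in range(i, len(end_probs)):
            if max_len is not None and j - i + 1 > max_len:
                break  # longer spans starting at i are not allowed
            score = p_start * end_probs[j]
            if score > best[2]:
                best = (i, j, score)
    return best


# First example from the printed output above: start row, then end row.
span = best_span([0.2175, 0.2090, 0.2347, 0.1744, 0.1645],
                 [0.2174, 0.2089, 0.2356, 0.1731, 0.1650])
print(span[:2])  # (2, 2)
```

With a PyTorch result `out = m(x)`, call `best_span(out[b, 0].tolist(), out[b, 1].tolist())` for batch element `b`. Taking the two argmaxes independently can yield an end before the start; the `i <= j` constraint rules that out at O(n²) cost, which is fine for sentence-length inputs.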

With BERT

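from PointerNet import PtrNet
import torch
from pytorch_transformers import BertModel

x = torch.arange(1, 6).unsqueeze(0)                  # token ids [[1, 2, 3, 4, 5]], shape (1, 5)
bert = BertModel.from_pretrained('base model path')  # path or name of a BERT-base checkpoint
embd, cls = bert(x)[:2]                              # sequence output (1, 5, 768) and pooled [CLS] output
m = PtrNet(768, 128, bi=True)                        # input size must match BERT's hidden size (768 for base)
print(m(embd))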
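When `PtrNet` is used as a BERT output layer for span extraction, it has to be trained against gold start and end positions. The README does not describe training, so the following assumes one common choice: the average negative log-likelihood of the gold start and end positions. Because the printed rows sum to 1, the layer appears to output probabilities rather than logits, so the sketch takes the log itself. `boundary_loss` is a hypothetical helper and again assumes the layout `(batch, 2, seq_len)` with the start pointer first.

```python
import torch
import torch.nn.functional as F


def boundary_loss(probs, start_idx, end_idx, eps=1e-12):
    """Average NLL of the gold start and end positions (hypothetical helper).

    probs:     (batch, 2, seq_len) probabilities, assumed order [start, end]
    start_idx: (batch,) long tensor of gold start positions
    end_idx:   (batch,) long tensor of gold end positions
    """
    # Output rows appear to be softmax probabilities, so take the log here;
    # eps guards against log(0).
    log_probs = torch.log(probs + eps)
    start_loss = F.nll_loss(log_probs[:, 0], start_idx)
    end_loss = F.nll_loss(log_probs[:, 1], end_idx)
    return (start_loss + end_loss) / 2


# Uniform probabilities over 5 positions give a loss of log(5).
probs = torch.full((2, 2, 5), 0.2, requires_grad=True)
loss = boundary_loss(probs, torch.tensor([0, 3]), torch.tensor([1, 4]))
loss.backward()
print(loss.item())  # ≈ 1.6094
```

In a real setup `probs` would be `m(embd)` from the snippet above, and gradients would flow back through `PtrNet` into BERT. If the layer were changed to return logits, `F.cross_entropy` on the raw scores would be the numerically safer equivalent.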
