Skip to content

Commit

Permalink
style & typing fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
bengioe committed Aug 23, 2023
1 parent 4eb3e2f commit 6c477e7
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 17 deletions.
4 changes: 1 addition & 3 deletions src/gflownet/data/sampling_iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,10 @@
from copy import deepcopy
from typing import Callable, List

import networkx as nx
import numpy as np
import torch
import torch.nn as nn
from rdkit import Chem, RDLogger
from rdkit import RDLogger
from torch.utils.data import Dataset, IterableDataset

from gflownet.data.replay_buffer import ReplayBuffer
Expand Down Expand Up @@ -222,7 +221,6 @@ def __iter__(self):
), "FlatRewards should be (mbsize, n_objectives), even if n_objectives is 1"
# The task may decide some of the mols are invalid, we have to again filter those
valid_idcs = valid_idcs[m_is_valid]
valid_mols = [m for m, v in zip(mols, m_is_valid) if v]
pred_reward = torch.zeros((num_online, online_flat_rew.shape[1]))
pred_reward[valid_idcs - num_offline] = online_flat_rew
is_valid[num_offline:] = False
Expand Down
1 change: 1 addition & 0 deletions src/gflownet/envs/graph_building_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -781,6 +781,7 @@ class GraphBuildingEnvContext:
"""A context class defines what the graphs are, how they map to and from data"""

device: torch.device
action_type_order: List[GraphActionType]

def aidx_to_GraphAction(self, g: gd.Data, action_idx: Tuple[int, int, int], fwd: bool = True) -> GraphAction:
"""Translate an action index (e.g. from a GraphActionCategorical) to a GraphAction
Expand Down
14 changes: 9 additions & 5 deletions src/gflownet/envs/seq_building_env.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
from typing import Any, List, Tuple
from copy import deepcopy
from typing import Any, List, Tuple

import torch
import torch_geometric.data as gd
from torch_geometric.data import Data
from torch.nn.utils.rnn import pad_sequence
from torch_geometric.data import Data

from gflownet.envs.graph_building_env import Graph, GraphAction
from .graph_building_env import Graph, GraphAction, GraphActionType, GraphBuildingEnv, GraphBuildingEnvContext
from gflownet.envs.graph_building_env import (
Graph,
GraphAction,
GraphActionType,
GraphBuildingEnv,
GraphBuildingEnvContext,
)


# For typing's sake, we'll pretend that a sequence is a graph.
Expand Down
7 changes: 3 additions & 4 deletions src/gflownet/models/seq_transformer.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np

from gflownet.config import Config
from gflownet.envs.graph_building_env import GraphActionCategorical, GraphBuildingEnvContext
from gflownet.envs.seq_building_env import SeqBatch
from gflownet.config import Config


class MLP(nn.Module):
Expand Down
6 changes: 1 addition & 5 deletions src/gflownet/tasks/toy_seq.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,11 @@
import os
import shutil
import socket
from typing import Callable, Dict, List, Tuple, Union
from typing import Dict, List, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch_geometric.data as gd
from rdkit.Chem.rdchem import Mol as RDMol
from torch import Tensor
from torch.utils.data import Dataset

from gflownet.config import Config
from gflownet.envs.seq_building_env import GenericSeqBuildingContext, SeqBuildingEnv
Expand Down

0 comments on commit 6c477e7

Please sign in to comment.