Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/feat_ppw2' into feat_rhand
Browse files Browse the repository at this point in the history
  • Loading branch information
jtwhite79 committed Jan 6, 2025
2 parents 5c480d1 + 10a9129 commit 1366926
Showing 1 changed file with 53 additions and 6 deletions.
59 changes: 53 additions & 6 deletions pyemu/utils/os_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,7 @@ def start_workers(
args = (os.path.join(worker_dir,pst_rel_path),hostname,port)
for i in range(num_workers):
p = mp.Process(target=ppw_function,args=args,kwargs=ppw_kwargs)
p.daemon = True
p.start()
procs.append(p)

Expand Down Expand Up @@ -674,14 +675,15 @@ def __init__(self, pst, host, port, timeout=0.1,verbose=True):
self.obs_names = None

self.par_values = None

self.max_reconnect_attempts = 10
self._process_pst()
self.connect()
self._lock = threading.Lock()
self._send_lock = threading.Lock()
self._listen_thread = threading.Thread(target=self.listen,args=(self._lock,self._send_lock))
self._listen_thread.start()


def _process_pst(self):
if isinstance(self._pst_arg,str):
self._pst = pst_handler.Pst(self._pst_arg)
Expand All @@ -692,7 +694,7 @@ def _process_pst(self):
format(type(self._pst_arg)))


def connect(self):
def connect(self,is_reconnect=False):
self.message("trying to connect to {0}:{1}...".format(self.host,self.port))
self.s = None
c = 0
Expand All @@ -703,6 +705,10 @@ def connect(self):
c += 1
if c % 75 == 0:
print('')
print(c)
if is_reconnect and c > self.max_reconnect_attempts:
print("max reconnect attempts reached...")
return False
self.s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.s.connect((self.host, self.port))
self.message("connected to {0}:{1}".format(self.host,self.port))
Expand All @@ -712,6 +718,10 @@ def connect(self):
continue
except Exception as e:
continue

self.net_pack = NetPack(timeout=self.timeout,verbose=self.verbose)
return True


def message(self,msg):
if self.verbose:
Expand All @@ -726,18 +736,31 @@ def recv(self,dtype=None):


def send(self,mtype,group,runid,desc="",data=0):
self.net_pack.send(self.s,mtype,group,runid,desc,data)
try:
self.net_pack.send(self.s,mtype,group,runid,desc,data)
except Exception as e:
print("WARNING: error sending message:{0}".format(str(e)))
return False
self.message("sent message type:{0}".format(NetPack.netpack_type[mtype]))
return True

def listen(self,lock=None,send_lock=None):
self.s.settimeout(self.timeout)
failed_reconnect = False
while True:
time.sleep(self.timeout)
try:
n = self.recv()
except Exception as e:
print("WARNING: recv exception:"+str(e)+"...trying to reconnect...")
self.connect()
success = self.connect(is_reconnect=True)
if not success:
print("...exiting")
time.sleep(self.timeout)
return
else:
print("...reconnect successfully...")
continue

if n > 0:
# need to sync here
Expand Down Expand Up @@ -776,20 +799,40 @@ def listen(self,lock=None,send_lock=None):
elif self.net_pack.mtype == 6:
if self._send_lock is not None:
self._send_lock.acquire()
self.send(7, self.net_pack.group,
success = self.send(7, self.net_pack.group,
self.net_pack.runid,
"fake linpack result", data=1)
if self._send_lock is not None:
self._send_lock.release()
if not success:
print("...trying to reconnect...")
success = self.connect(is_reconnect=True)
if not success:
print("...exiting")
time.sleep(self.timeout)
return
else:
print("reconnect successfully...")
continue

elif self.net_pack.mtype == 15:
if self._send_lock is not None:
self._send_lock.acquire()
self.send(15, self.net_pack.group,
sucess = self.send(15, self.net_pack.group,
self.net_pack.runid,
"ping back")
if self._send_lock is not None:
self._send_lock.release()
if not success:
print("...trying to reconnect...")
success = self.connect(is_reconnect=True)
if not success:
print("...exiting")
time.sleep(self.timeout)
return
else:
print("reconnect successfully...")
continue
elif self.net_pack.mtype == 14:
#print("recv'd terminate signal")
self.message("recv'd terminate signal")
Expand Down Expand Up @@ -819,6 +862,7 @@ def get_parameters(self):
raise Exception("len(par vals) {0} != len(par names)".format(len(pars),len(self.par_names)))
return pd.Series(data=pars,index=self.par_names)


def send_observations(self,obsvals,parvals=None,request_more_pars=True):
if len(obsvals) != len(self.obs_names):
raise Exception("len(obs vals) {0} != len(obs names)".format(len(obsvals), len(self.obs_names)))
Expand Down Expand Up @@ -862,11 +906,13 @@ def send_observations(self,obsvals,parvals=None,request_more_pars=True):
self.send(3,0,0,"ready for next run",data=0)
self._send_lock.release()


def request_more_pars(self):
self._send_lock.acquire()
self.send(3, 0, 0, "ready for next run", data=0.0)
self._send_lock.release()


def send_failed_run(self,group=None,runid=None,desc="failed"):
if group is None:
group = self.net_pack.group
Expand All @@ -876,6 +922,7 @@ def send_failed_run(self,group=None,runid=None,desc="failed"):
self.send(12, int(group), int(runid), desc, data=0.0)
self._send_lock.release()


def send_killed_run(self,group=None,runid=None,desc="killed"):
if group is None:
group = self.net_pack.group
Expand Down

0 comments on commit 1366926

Please sign in to comment.