diff --git a/demo/run_dgp_demo.py b/demo/run_dgp_demo.py index 875a6bd..178e730 100644 --- a/demo/run_dgp_demo.py +++ b/demo/run_dgp_demo.py @@ -21,7 +21,7 @@ os.environ["DLClight"] = "True" os.environ["Colab"] = "True" from deeplabcut.utils import auxiliaryfunctions - +from deepgraphpose.contrib.segment_videos import split_video from deepgraphpose.models.fitdgp import fit_dlc, fit_dgp, fit_dgp_labeledonly from deepgraphpose.models.fitdgp_util import get_snapshot_path from deepgraphpose.models.eval import plot_dgp @@ -136,6 +136,8 @@ def get_init_weights_path(base_path): default=10, help="size of the batch, if there are memory issues, decrease it value") parser.add_argument("--test", action='store_true', default=False) + parser.add_argument("--split", action = "store_true",help = "whether or not we should run inference on chopped up videos") + parser.add_argument("--splitlength", default = 6000, help= "number of frames in block if splitting videos. ") input_params = parser.parse_known_args()[0] print(input_params) @@ -146,6 +148,8 @@ def get_init_weights_path(base_path): batch_size = input_params.batch_size test = input_params.test + splitflag,splitlength = input_params.split,input_params.splitlength + update_configs = False if dlcpath == join('data','Reaching-Mackenzie-2018-08-30'): # update config files @@ -283,6 +287,14 @@ def get_init_weights_path(base_path): os.makedirs(video_pred_path) print('video_sets', video_sets, flush=True) + if splitflag: + video_cut_path = str(Path(dlcpath) / 'videos_cut') + if not os.path.exists(video_cut_path): + os.makedirs(video_cut_path) + clip_sets = [] + for v in video_sets: + clip_sets.extend(split_video(v,int(splitlength),suffix = "demo",outputloc = video_cut_path)) + video_sets = clip_sets ## replace video_sets with clipped versions. if test: for video_file in [video_sets[0]]: diff --git a/src/deepgraphpose/contrib/segment_videos.py b/src/deepgraphpose/contrib/segment_videos.py new file mode 100644 index 0000000..3856610 --- /dev/null +++ b/src/deepgraphpose/contrib/segment_videos.py @@ -0,0 +1,85 @@ +import os +from moviepy.editor import VideoFileClip +# Given a DLC project, check all videos and clip those that are longer than some tolerance into shorter disjoint clips. + +# Take the project, and look within for all videos that will be trained on (these come from the config file) and analyzed (these come from the folder videos_dgp). +# Those videos that have labels need to be split on the training labels as well. + +def split_video_and_trainframes(config_path,tol=5000,suffix=None): + """Splits videos and trainframes in a model config that are larger than some tolerance in frames. + + :param config_path: parameter to config file. + :param tol: tolerance in number of frames. + :param suffix: video suffix. + """ + trainvids = check_videos(config_path,tol) + analyzevids = check_analysis_videos(folder,tol) + splitlength = tol + vids = trainvids + analyzevids + for v in vids: + split_video(v,splitlength,suffix) + if v in trainvids: + format_frames(v,splitlength,suffix) + +def check_videos(config_path,tol): + """Checks all videos given in the model cfg file and checks if any are longer than the given length. + + :param config_path: parameter to config file. + :param tol: tolerance in number of frames. + """ + +def check_analysis_videos(folder_path,tol): + """Checks all videos given in the videos_dgp directory and checks if any are longer than the given length. + + :param config_path: parameter to config file. + :param tol: tolerance in number of frames. + """ + +def split_video(vidpath,splitlength,suffix = "",outputloc = None): + """splits a given video into subclips of type mp4. Note: will work best (even frames per subclip) if you pass a splitlength that is divisible by your frame rate. + + :param vidpath: path to video + :param splitlength: length to chunk into in frames + :param suffix: custom suffix to add to subclips + :param outputloc: directory to write outputs to. Default is same directory. + :returns: list of paths to new video files. + """ + try: + clip = VideoFileClip(vidpath) + except FileNotFoundError: + print("file not found.") + + duration = clip.duration + splitlength_secs = splitlength/clip.fps + viddir,vidname = os.path.dirname(vidpath),os.path.basename(vidpath) + base,ext = os.path.splitext(vidname) + subname = base+suffix+"{n}"+".mp4" + if outputloc is None: + subpath = os.path.join(viddir,subname) + else: + subpath = os.path.join(outputloc,subname) + + clipnames = [] + clipstart = 0 + clipind = 0 + while clipstart < duration: + subname = subpath.format(n=clipind) + subclip = clip.subclip(clipstart,min(duration,clipstart+splitlength_secs)) + subclip.write_videofile(subname,codec = "mpeg4") + clipnames.append(subname) + clipstart += splitlength_secs + clipind+=1 + return clipnames + + + + +def format_frames(vidpath,splitlength,suffix = None): + """reformats training frames into format that matches sublclips + + :param vidpath: path to video + :param splitlength: length to chunk into + :param suffix: custom suffix to add to subclips + """ + + diff --git a/tests/test_segment_videos.py b/tests/test_segment_videos.py new file mode 100644 index 0000000..2afb479 --- /dev/null +++ b/tests/test_segment_videos.py @@ -0,0 +1,38 @@ +import pytest +from deepgraphpose.contrib.segment_videos import split_video +from moviepy.editor import VideoFileClip +import math +import os + +here = os.path.abspath(os.path.dirname(__file__)) +testmodel_path = os.path.join(here,"testmodel") + +def test_split_video(tmp_path): + output = tmp_path/"subclips" + output.mkdir() + frame_duration = 30 + vidpath = os.path.join(testmodel_path,"videos","reachingvideo1.avi") + video_locs = split_video(vidpath,frame_duration,suffix = "test",outputloc = str(output)) + + origclip = VideoFileClip(vidpath) + duration = origclip.duration*origclip.fps + assert len(video_locs) == math.ceil(duration/frame_duration) + vid_inds = [] + for vi in video_locs: + prefix = os.path.splitext(os.path.basename(vidpath))[0]+"test" + assert os.path.splitext(os.path.basename(vi))[0].startswith(prefix) + vid_inds.append(int(vi.split(prefix)[-1].split(".mp4")[0])) + sub = VideoFileClip(vi) + assert sub.duration*sub.fps - frame_duration < 1e-1 + assert set(vid_inds) == set(range(len(video_locs))) + + + + + + + + + + +