diff --git a/train_network.py b/train_network.py index cf64c894..8be8305c 100644 --- a/train_network.py +++ b/train_network.py @@ -134,6 +134,8 @@ def train(args): gc.collect() # prepare network + import sys + sys.path.append(os.path.dirname(__file__)) print("import network module:", args.network_module) network_module = importlib.import_module(args.network_module)