From c414d0284ae72fc4681f5ef846f6ef77d3434253 Mon Sep 17 00:00:00 2001 From: Aleksa Gordic Date: Fri, 2 Aug 2024 23:12:49 +0200 Subject: [PATCH] Minor refactor --- llmc_py/utils.py | 4 ++++ train_gpt2.py | 3 ++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/llmc_py/utils.py b/llmc_py/utils.py index 66ff7a42e..ed023c78a 100644 --- a/llmc_py/utils.py +++ b/llmc_py/utils.py @@ -1,3 +1,7 @@ +# Taken from: +# 1) https://github.com/meta-llama/llama-models/blob/main/models/llama3_1/api/model.py +# 2) https://github.com/meta-llama/llama3/blob/11817d47e1ba7a4959b025eb1ca308572e0e3963/llama/generation.py + import torch from torch import nn diff --git a/train_gpt2.py b/train_gpt2.py index 46a264701..b81effe44 100644 --- a/train_gpt2.py +++ b/train_gpt2.py @@ -24,7 +24,6 @@ from contextlib import nullcontext from dataclasses import dataclass from pathlib import Path - from typing import ( List, Optional, @@ -988,6 +987,8 @@ def print0(*args, **kwargs): print(f"{result['generation']}") print("\n==================================\n") + exit(0) # only inference supported for now + # ------------------------------------------------------------------------- # PyTorch -> C bridge: save some weights and state for C to load later as reference