Skip to content

Commit

Permalink
Renamed train to fit and made split_index public
Browse files: browse the repository at this point in the history
  • Loading branch information
bluesheeptoken committed Mar 17, 2019
1 parent ea95290 commit 0a7d674
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 36 deletions.
10 changes: 5 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ You can test the model with the following code
from cpt.cpt import Cpt
model = Cpt()

model.train([['hello', 'world'],
['hello', 'this', 'is', 'me'],
['hello', 'me']
])
model.fit([['hello', 'world'],
['hello', 'this', 'is', 'me'],
['hello', 'me']
])

model.predict([['hello'], ['hello', 'this']])
# Output: ['me', 'is']
Expand Down Expand Up @@ -49,7 +49,7 @@ import pickle


model = Cpt()
model.train([['hello', 'world']])
model.fit([['hello', 'world']])

dumped = pickle.dumps(model)

Expand Down
3 changes: 2 additions & 1 deletion cpt/cpt.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ cdef class Cpt:

cpdef predict(self, list sequences, float noise_ratio=*, int MBR=*, bint multithreading=*)

cdef public int split_index

cdef readonly:
int split_index
Alphabet alphabet
size_t number_trained_sequences
18 changes: 9 additions & 9 deletions cpt/cpt.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -32,21 +32,21 @@ cdef class Cpt:
number_trained_sequences : int
the number of sequences used for training
'''
def __cinit__(self, int split_length=0):
if split_length < 0:
raise ValueError('split_length value should be non-negative, actual value: {}'.format(split_length))
def __init__(self, int split_index=0):
if split_index < 0:
raise ValueError('split_index value should be non-negative, actual value: {}'.format(split_index))
self.tree = PredictionTree()
self.inverted_index = vector[Bitset]()
self.lookup_table = vector[Node]()
self.split_index = -split_length
self.split_index = -split_index
self.alphabet = Alphabet()
self.number_trained_sequences = 0

def train(self, sequences):
def fit(self, sequences):
'''Train the model
The model can be retrained to add new sequences
``model.train(seq1);model.train(seq2)`` is equivalent to
``model.train(seq1 + seq2)`` with seq1, seq2 list of sequences
``model.fit(seq1);model.fit(seq2)`` is equivalent to
``model.fit(seq1 + seq2)`` with seq1, seq2 list of sequences
Parameters
----------
Expand All @@ -60,7 +60,7 @@ cdef class Cpt:
Examples
--------
>>> model.train([['hello', 'world'], ['hello', 'cpt']])
>>> model.fit([['hello', 'world'], ['hello', 'cpt']])
'''
cdef size_t number_sequences_to_train = len(sequences)
cdef Node current
Expand Down Expand Up @@ -121,7 +121,7 @@ cdef class Cpt:
>>> model = Cpt()
>>> model.train([['hello', 'world'],
>>> model.fit([['hello', 'world'],
['hello', 'this', 'is', 'me'],
['hello', 'me']
])
Expand Down
4 changes: 2 additions & 2 deletions profiling/profiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@ def profile(mode, data_path, profile_path):
cpt = Cpt()

if mode == 'predict':
cpt.train(data)
cpt.fit(data)
data = [sequence[-10:] for sequence in data]
cProfile.runctx('cpt.predict(data, 0.2, 10)', None, locals(), profile_path)
else:
cProfile.runctx('cpt.train(data)', None, locals(), profile_path)
cProfile.runctx('cpt.fit(data)', None, locals(), profile_path)


if __name__ == '__main__':
Expand Down
38 changes: 19 additions & 19 deletions tests/test_cpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@ def setUpClass(cls):
['B', 'C', 'D'],
['B', 'D', 'E']]

cls.cpt.train(cls.sequences)
cls.cpt.fit(cls.sequences)

def test_init(self):
with self.assertRaises(ValueError):
Cpt(-5)

def test_train(self):
def test_fit(self):
alphabet = Alphabet()
alphabet.length = 5
alphabet.indexes = {'A': 0, 'B': 1, 'C': 2, 'D': 3, 'E': 4}
Expand Down Expand Up @@ -54,7 +54,7 @@ def test_predict(self):

def test_richcmp(self):
cpt_wrong_split_index = Cpt(1)
cpt_wrong_split_index.train(self.sequences)
cpt_wrong_split_index.fit(self.sequences)
self.assertNotEqual(self.cpt, cpt_wrong_split_index)
self.assertEqual(self.cpt, self.cpt)

Expand All @@ -63,34 +63,34 @@ def test_pickle(self):
unpickled_cpt = pickle.loads(pickled)
self.assertEqual(self.cpt, unpickled_cpt)

def test_retrain(self):
def test_refit(self):
'''
The bitset is coded on 8 bits, so
we need to train with at least 9 sequences to test the resize method.
'''
model_no_retrain = Cpt()
model_no_retrain.train([['C', 'P', 'T', '1'],
model_no_retrain.fit([['C', 'P', 'T', '1'],
['C', 'P', 'T', '2'],
['C', 'P', 'T', '3'],
['C', 'P', 'T', '4'],
['C', 'P', 'T', '5'],
['C', 'P', 'T', '6'],
['C', 'P', 'T', '7'],
['C', 'P', 'T', '8'],
['C', 'P', 'T', '9']
])

model_with_retrain = Cpt()
model_with_retrain.fit([['C', 'P', 'T', '1'],
['C', 'P', 'T', '2'],
['C', 'P', 'T', '3'],
['C', 'P', 'T', '4'],
['C', 'P', 'T', '5'],
['C', 'P', 'T', '6'],
['C', 'P', 'T', '7'],
['C', 'P', 'T', '8'],
['C', 'P', 'T', '9']
['C', 'P', 'T', '8']
])

model_with_retrain = Cpt()
model_with_retrain.train([['C', 'P', 'T', '1'],
['C', 'P', 'T', '2'],
['C', 'P', 'T', '3'],
['C', 'P', 'T', '4'],
['C', 'P', 'T', '5'],
['C', 'P', 'T', '6'],
['C', 'P', 'T', '7'],
['C', 'P', 'T', '8']
])
model_with_retrain.train([['C', 'P', 'T', '9']])
model_with_retrain.fit([['C', 'P', 'T', '9']])

self.assertEqual(model_no_retrain, model_with_retrain)

Expand Down

0 comments on commit 0a7d674

Please sign in to comment.