This repository has been archived by the owner on May 22, 2019. It is now read-only.

Commit

Merge pull request #372 from EgorBu/master

Several fixes for topic modeling pipeline

vmarkovtsev authored Feb 14, 2019
2 parents 18be436 + bce4238 commit f1da33f
Showing 6 changed files with 23 additions and 17 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
@@ -2,7 +2,7 @@ Cython>=0.28,<1.0; python_version == '3.7'
 PyStemmer==1.3.0
 bblfsh==2.12.7
 modelforge==0.11.0
-sourced-engine==0.7.0
+sourced-jgit-spark-connector==2.0.1
 parquet==1.2
 numpy==1.15.4
 humanize==0.5.1

2 changes: 1 addition & 1 deletion setup.py
@@ -34,7 +34,7 @@
         "PyStemmer>=1.3,<2.0",
         "bblfsh>=2.2.1,<3.0",
         "modelforge>=0.11.0,<0.12",
-        "sourced-engine>=0.7.0,<1.1",
+        "sourced-jgit-spark-connector>=2.0.1,<2.1.0",
         "humanize>=0.5.0,<0.6",
         "parquet>=1.2,<2.0",
         "pygments>=2.2.0,<3.0",

6 changes: 3 additions & 3 deletions sourced/ml/cmd/repos2bow.py
@@ -80,12 +80,12 @@ def keymap(r):
         .execute()
     uast_extractor.unpersist()
     tokens = {row.token for row in tokens}
-    reduced_token_freq = {key: df_model[key] for key in df_model.df if key in tokens}
-    reduced_token_index = {key: df_model.order[key] for key in df_model.df if key in tokens}
+    reduced_token_freq = {key: df_model[key] for key in df_model._df if key in tokens}
+    reduced_token_index = {key: df_model.order[key] for key in df_model._df if key in tokens}
     log.info("Processing %s distinct tokens", len(reduced_token_freq))
     log.info("Indexing by document and token ...")
     bags_writer = bags \
-        .link(TFIDF(reduced_token_freq, df_model.docs, root.session.sparkContext)) \
+        .link(TFIDF(reduced_token_freq, df_model.docs, root.engine.session.sparkContext)) \
         .link(document_indexer) \
         .link(Indexer(Uast2BagFeatures.Columns.token, reduced_token_index))
     if save_hook is not None:

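The two dict comprehensions above shrink the document-frequency model to the tokens that actually occur in the corpus being processed (the commit now reads the model's internal _df mapping rather than df). A minimal sketch with made-up data, using plain dicts in place of the document-frequency model:

    df = {"foo": 10, "bar": 3, "baz": 7}    # token -> document frequency
    order = {"foo": 0, "bar": 1, "baz": 2}  # token -> matrix column index
    tokens = {"foo", "baz"}                 # tokens seen in this run

    # Keep only the frequencies and column indices of tokens in the corpus.
    reduced_token_freq = {key: df[key] for key in df if key in tokens}
    reduced_token_index = {key: order[key] for key in order if key in tokens}
    assert reduced_token_freq == {"foo": 10, "baz": 7}
    assert reduced_token_index == {"foo": 0, "baz": 2}
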
2 changes: 1 addition & 1 deletion sourced/ml/tests/test_engine.py
@@ -9,4 +9,4 @@ def test_bblfsh_dependency(self):
                          "spark.tech.sourced.bblfsh.grpc.host=localhost")

     def test_engine_dependencies(self):
-        self.assertEqual(get_engine_package("latest"), "tech.sourced:engine:latest")
+        self.assertEqual(get_engine_package("latest"), "tech.sourced:jgit-spark-connector:latest")

17 changes: 11 additions & 6 deletions sourced/ml/transformers/indexer.py
@@ -93,11 +93,16 @@ def index_column(row):
             Please contact me if you have troubles: kslavnov@gmail.com
             """
             if isinstance(column_name, str):
-                assert isinstance(row, Row)
-                row_dict = row.asDict()
-                row_dict[column_name] = column2id.value[row_dict[column_name]]
-                return Row(**row_dict)
-            return row[:column_name] + (column2id.value[row[column_name]],) + row[column_name + 1:]
-        indexed_rdd = rdd.map(index_column)
+                try:
+                    assert isinstance(row, Row)
+                    row_dict = row.asDict()
+                    row_dict[column_name] = column2id.value[row_dict[column_name]]
+                    return [Row(**row_dict)]
+                except KeyError:
+                    return []
+            return [row[:column_name] + (column2id.value[row[column_name]],) +
+                    row[column_name + 1:]]
+
+        indexed_rdd = rdd.flatMap(index_column)
         column2id.unpersist(blocking=True)
         return indexed_rdd

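The switch from map() to flatMap() is what turns the KeyError handling into a filter: each call now returns a list, flatMap concatenates those lists, so a one-element list keeps the row and an empty list drops it, instead of a KeyError failing the whole Spark task. A runnable sketch under assumed toy data (local Spark, hypothetical "token" column and vocabulary):

    from pyspark.sql import Row, SparkSession

    spark = SparkSession.builder.master("local[1]").getOrCreate()
    # Broadcast a toy token -> id vocabulary; "unknown" is deliberately absent.
    column2id = spark.sparkContext.broadcast({"foo": 0, "bar": 1})
    rdd = spark.sparkContext.parallelize(
        [Row(token="foo"), Row(token="unknown"), Row(token="bar")])

    def index_column(row):
        try:
            row_dict = row.asDict()
            row_dict["token"] = column2id.value[row_dict["token"]]
            return [Row(**row_dict)]  # one-element list keeps the row
        except KeyError:
            return []                 # empty list drops the row entirely

    print(rdd.flatMap(index_column).collect())  # [Row(token=0), Row(token=1)]
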
11 changes: 6 additions & 5 deletions sourced/ml/utils/engine.py
@@ -10,13 +10,14 @@

 def get_engine_version():
     try:
-        engine = get_distribution("sourced-engine").version
+        engine = get_distribution("sourced-jgit-spark-connector").version
     except DistributionNotFound:
         log = logging.getLogger("engine_version")
-        engine = requests.get("https://api.github.com/repos/src-d/engine/releases/latest") \
+        engine = requests.get("https://api.github.com/repos/src-d/sourced-jgit-spark-connector/"
+                              "releases/latest") \
             .json()["tag_name"].replace("v", "")
-        log.warning("Engine not found, queried GitHub to get the latest release tag (%s)",
-                    engine)
+        log.warning("jgit-spark-connector not found, queried GitHub to get the latest release tag "
+                    "(%s)", engine)
     return engine


@@ -54,7 +55,7 @@ def add_engine_args(my_parser, default_packages=None):


 def get_engine_package(engine):
-    return "tech.sourced:engine:" + engine
+    return "tech.sourced:jgit-spark-connector:" + engine


 def get_bblfsh_dependency(bblfsh):

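When the Python distribution is not installed, the fallback above resolves the latest release tag through the GitHub API. A minimal sketch of just that path, using the endpoint from the patch (makes a live network call; the tag value shown is illustrative):

    import requests

    url = ("https://api.github.com/repos/src-d/sourced-jgit-spark-connector/"
           "releases/latest")
    tag = requests.get(url).json()["tag_name"]  # e.g. "v2.0.1"
    print(tag.replace("v", ""))                 # "v" stripped -> "2.0.1"
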
