wip: prefix tree
This commit is contained in:
parent
e82bc7490b
commit
50fac496ea
3 changed files with 99 additions and 32 deletions
|
@ -64,26 +64,26 @@ class Indexer(IndexerBase):
|
|||
|
||||
def query(self, query: str):
|
||||
start_time = perf_counter()
|
||||
trigram_results = self._trigram_index.query(query)
|
||||
|
||||
logger.info(f"Narrowed down to {len(trigram_results)} files via trigram search")
|
||||
|
||||
leads = self._trigram_index.query(query)
|
||||
logger.info(f"Narrowed down to {len(leads)} files via trigram search")
|
||||
confirmed = []
|
||||
uniques = 0
|
||||
for lead in trigram_results:
|
||||
lead_content = self.corpus.get_document(key=lead[1]).content
|
||||
for lead in leads:
|
||||
uid, score = lead
|
||||
lead_content = self.corpus.get_document(uid=uid).content
|
||||
lead_path = self.corpus.get_document(uid=uid).key
|
||||
results = re.finditer(query, lead_content)
|
||||
hits_in_lead = []
|
||||
for hit in results:
|
||||
start_line, end_line = self._find_line_range(
|
||||
lead[1], hit.start(), hit.end()
|
||||
lead_path, hit.start(), hit.end()
|
||||
)
|
||||
start_offset = self._line_index.query(lead[1])[start_line][0]
|
||||
end_offset = self._line_index.query(lead[1])[end_line][1]
|
||||
start_offset = self._line_index.query(lead_path)[start_line][0]
|
||||
end_offset = self._line_index.query(lead_path)[end_line][1]
|
||||
|
||||
hits_in_lead.append(
|
||||
SearchResult(
|
||||
key=lead[1],
|
||||
key=lead_path,
|
||||
offset_start=start_offset,
|
||||
offset_end=end_offset,
|
||||
line_start=start_line,
|
||||
|
@ -136,22 +136,27 @@ class Indexer(IndexerBase):
|
|||
processes = settings.INDEXING_PROCESSES
|
||||
pool = Pool(processes=processes)
|
||||
chunks = chunkify_content(uids, processes)
|
||||
results = pool.map(self._bulk_process, chunks)
|
||||
# TODO: Refactor indices to populate cleanly.
|
||||
for result in results:
|
||||
self._trigram_index._trigrams.update(result[0])
|
||||
processed_chunks = pool.map(self._bulk_process, chunks)
|
||||
|
||||
for result in processed_chunks:
|
||||
for uid in result[0]:
|
||||
self._trigram_index.index(
|
||||
uid.replace("document:", ""), None, None, result[0][uid]
|
||||
)
|
||||
self._line_index._lines.update(result[1])
|
||||
|
||||
# TODO: Tidy up, rethink w.r.t. multiprocessing.
|
||||
def _bulk_process(self, uids: List[str]):
|
||||
trigrams = {}
|
||||
for uid in uids:
|
||||
document = self.corpus.get_document(uid=uid)
|
||||
path = document.key
|
||||
content = document.content
|
||||
self._trigram_index.index(path, content)
|
||||
trigrams[uid] = TrigramIndex.trigramize(content)
|
||||
self._line_index.index(path, content)
|
||||
logger.info(f"[Indexing] Processed {path}")
|
||||
|
||||
return (self._trigram_index._trigrams, self._line_index._lines)
|
||||
return (trigrams, self._line_index._lines)
|
||||
|
||||
def _find_closest_line(self, path, index):
|
||||
content = self._line_index.query(path)
|
||||
|
|
51
src/prefix_tree.py
Normal file
51
src/prefix_tree.py
Normal file
|
@ -0,0 +1,51 @@
|
|||
import attr
|
||||
|
||||
|
||||
@attr.s
|
||||
class PrefixTree:
|
||||
root = attr.ib()
|
||||
|
||||
@staticmethod
|
||||
def initialize():
|
||||
root = PrefixTreeNode(value=None)
|
||||
return PrefixTree(root=root)
|
||||
|
||||
def insert(self, value, key, current=None):
|
||||
if current is None:
|
||||
current = self.root
|
||||
|
||||
if not value:
|
||||
current.mappings.append(key)
|
||||
return
|
||||
top = value[0]
|
||||
rest = value[1:]
|
||||
|
||||
next_child = current.children.get(top)
|
||||
|
||||
if next_child:
|
||||
self.insert(rest, key, next_child)
|
||||
else:
|
||||
new_node = PrefixTreeNode(value=top)
|
||||
current.children[top] = new_node
|
||||
self.insert(rest, key, new_node)
|
||||
|
||||
def get(self, value, current=None):
|
||||
if not current:
|
||||
current = self.root
|
||||
if not value:
|
||||
return current.mappings
|
||||
|
||||
top = value[0]
|
||||
rest = value[1:]
|
||||
|
||||
next_child = current.children.get(top)
|
||||
|
||||
if next_child:
|
||||
return self.get(rest, next_child)
|
||||
|
||||
|
||||
@attr.s
|
||||
class PrefixTreeNode:
|
||||
value = attr.ib()
|
||||
mappings = attr.ib(default=attr.Factory(list))
|
||||
children = attr.ib(default=attr.Factory(dict))
|
|
@ -3,34 +3,45 @@ from typing import List, Optional
|
|||
import attr
|
||||
from settings import settings
|
||||
from base import IndexBase
|
||||
from prefix_tree import PrefixTree
|
||||
|
||||
|
||||
@attr.s
|
||||
class TrigramIndex(IndexBase):
|
||||
_trigrams = attr.ib(default=attr.Factory(dict))
|
||||
_threshold = attr.ib(default=settings.SIGNIFICANCE_THRESHOLD)
|
||||
_tree = attr.ib(attr.Factory(PrefixTree.initialize))
|
||||
|
||||
def index(self, key: str, content: str):
|
||||
self._trigrams[key] = self._trigramize(content)
|
||||
def index(self, uid, key: str, content: str, trigrams):
|
||||
if content:
|
||||
trigrams = TrigramIndex.trigramize(content)
|
||||
|
||||
for trigram in trigrams:
|
||||
self._tree.insert(trigram, uid)
|
||||
|
||||
def query(self, query: str, haystack: Optional[List[str]] = None) -> List[str]:
|
||||
if not haystack:
|
||||
haystack = self._trigrams
|
||||
query_trigrams = TrigramIndex.trigramize(query)
|
||||
results = {}
|
||||
|
||||
query_trigrams = self._trigramize(query)
|
||||
results = []
|
||||
for trigram in query_trigrams:
|
||||
result_set = self._tree.get(trigram)
|
||||
if result_set:
|
||||
results[trigram] = result_set
|
||||
|
||||
for item in haystack:
|
||||
shared = self._trigrams[item].intersection(query_trigrams)
|
||||
ratio = len(shared) / len(query_trigrams)
|
||||
if ratio < self._threshold:
|
||||
continue
|
||||
matches = {}
|
||||
|
||||
results.append((ratio, item, list(shared)))
|
||||
for result in results:
|
||||
for doc in results[result]:
|
||||
matches[doc] = matches.get(doc, 0) + 1
|
||||
|
||||
results.sort(reverse=True, key=lambda x: x[0])
|
||||
significant_results = []
|
||||
for uid, occurrences in matches.items():
|
||||
score = occurrences / len(query_trigrams)
|
||||
if score >= self._threshold:
|
||||
significant_results.append((f"document:{uid}", score))
|
||||
|
||||
return results
|
||||
significant_results.sort(reverse=True, key=lambda x: x[0])
|
||||
return significant_results
|
||||
|
||||
def _trigramize(self, content: str) -> List[str]:
|
||||
@staticmethod
|
||||
def trigramize(content: str) -> List[str]:
|
||||
return {content[pos : pos + 3].lower() for pos in range(len(content) - 2)}
|
||||
|
|
Reference in a new issue