diff --git a/.gitignore b/.gitignore
index b6e4761..b282e51 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,6 +2,7 @@
 __pycache__/
 *.py[cod]
 *$py.class
+*.sw[a-z]
 
 # C extensions
 *.so
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..64785ab
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,20 @@
+appdirs==1.4.4
+attrs==20.2.0
+black==20.8b1
+click==7.1.2
+iniconfig==1.0.1
+invoke==1.4.1
+more-itertools==8.5.0
+mypy-extensions==0.4.3
+packaging==20.4
+pathspec==0.8.0
+pluggy==0.13.1
+py==1.9.0
+pyinotify==0.9.6
+pyparsing==2.4.7
+pytest==6.0.2
+regex==2020.7.14
+six==1.15.0
+toml==0.10.1
+typed-ast==1.4.1
+typing-extensions==3.7.4.3
diff --git a/script/bootstrap b/script/bootstrap
new file mode 100644
index 0000000..e644547
--- /dev/null
+++ b/script/bootstrap
@@ -0,0 +1,15 @@
+VENV=codesearch.venv
+
+#################################################################
+# Bootstrapping sets up the Python 3.8 venv that allows the use #
+# of the invoke commands.                                       #
+#################################################################
+
+{
+    pyenv virtualenv-delete -f $VENV
+    pyenv virtualenv $VENV &&
+    pyenv activate $VENV &&
+    python -m pip install -U pip &&
+    pip install -r requirements.txt &&
+    echo "✨ Good to go! ✨"
+}
diff --git a/src/__init__.py b/src/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/__snapshots__/test_prefix_tree.ambr b/src/__snapshots__/test_prefix_tree.ambr
new file mode 100644
index 0000000..1ddf955
--- /dev/null
+++ b/src/__snapshots__/test_prefix_tree.ambr
@@ -0,0 +1,184 @@
+# name: test_base_tree_has_a_root_node
+  {
+    'children': [
+    ],
+    'mappings': [
+    ],
+    'value': None,
+  }
+---
+# name: test_insert_multiple_keys_same_string
+  {
+    'children': [
+      {
+        'children': [
+          {
+            'children': [
+              {
+                'children': [
+                  {
+                    'children': [
+                    ],
+                    'mappings': [
+                      'key_1',
+                      'key_2',
+                    ],
+                    'value': 'd',
+                  },
+                ],
+                'mappings': [
+                ],
+                'value': 'c',
+              },
+            ],
+            'mappings': [
+            ],
+            'value': 'b',
+          },
+        ],
+        'mappings': [
+        ],
+        'value': 'a',
+      },
+    ],
+    'mappings': [
+    ],
+    'value': None,
+  }
+---
+# name: test_insert_overlapping_strings
+  {
+    'children': [
+      {
+        'children': [
+          {
+            'children': [
+              {
+                'children': [
+                  {
+                    'children': [
+                    ],
+                    'mappings': [
+                      'key_1',
+                    ],
+                    'value': 'd',
+                  },
+                  {
+                    'children': [
+                    ],
+                    'mappings': [
+                      'key_2',
+                    ],
+                    'value': 'e',
+                  },
+                ],
+                'mappings': [
+                ],
+                'value': 'c',
+              },
+            ],
+            'mappings': [
+            ],
+            'value': 'b',
+          },
+        ],
+        'mappings': [
+        ],
+        'value': 'a',
+      },
+    ],
+    'mappings': [
+    ],
+    'value': None,
+  }
+---
+# name: test_insert_single_character_
+  {
+    'children': [
+      {
+        'children': [
+        ],
+        'mappings': [
+          'key_1',
+        ],
+        'value': 'a',
+      },
+    ],
+    'mappings': [
+    ],
+    'value': None,
+  }
+---
+# name: test_insert_single_string
+  {
+    'children': [
+      {
+        'children': [
+          {
+            'children': [
+              {
+                'children': [
+                ],
+                'mappings': [
+                  'key_1',
+                ],
+                'value': 'c',
+              },
+            ],
+            'mappings': [
+            ],
+            'value': 'b',
+          },
+        ],
+        'mappings': [
+        ],
+        'value': 'a',
+      },
+    ],
+    'mappings': [
+    ],
+    'value': None,
+  }
+---
+# name: test_insert_strings_subsets_of_each_other
+  {
+    'children': [
+      {
+        'children': [
+          {
+            'children': [
+              {
+                'children': [
+                  {
+                    'children': [
+                    ],
+                    'mappings': [
+                      'key_1',
+                    ],
+                    'value': 'd',
+                  },
+                ],
+                'mappings': [
+                  'key_2',
+                ],
+                'value': 'c',
+              },
+            ],
+            'mappings': [
+            ],
+            'value': 'b',
+          },
+        ],
+        'mappings': [
+        ],
+        'value': 'a',
+      },
+    ],
+    'mappings': [
+    ],
+    'value': None,
+  }
+---
+# name: test_serializes_to_json
"mappings": [], "children": [{"value": "a", "mappings": [], "children": [{"value": "b", "mappings": [], "children": [{"value": "c", "mappings": [], "children": [{"value": "d", "mappings": ["key_1"], "children": []}]}]}]}]}' +--- diff --git a/src/codesearch/__init__.py b/src/codesearch/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/codesearch/base.py b/src/codesearch/base.py new file mode 100644 index 0000000..7fc2d3c --- /dev/null +++ b/src/codesearch/base.py @@ -0,0 +1,18 @@ +from abc import ABC, abstractmethod +from typing import List, Optional + + +class IndexBase(ABC): + @abstractmethod + def index(self, content: str, haystack: Optional[List[str]]): + pass + + @abstractmethod + def query(self, query: str) -> List[str]: + pass + + +class IndexerBase(ABC): + @abstractmethod + def index(self, paths: List[str]): + pass diff --git a/src/codesearch/cli.py b/src/codesearch/cli.py new file mode 100644 index 0000000..5f66ddb --- /dev/null +++ b/src/codesearch/cli.py @@ -0,0 +1,30 @@ +import argparse + +from pathlib import Path +from .server import Server +from .indexer import Indexer +from .client import search +from .settings import settings + + +def main(): + parser = argparse.ArgumentParser() + + parser.add_argument("command") + parser.add_argument("--q", required=False) + + args = parser.parse_args() + + if args.command == "start": + watched = [Path(p).expanduser() for p in settings.WATCHED] + server = Server( + indexer=Indexer( + domain=watched, + exclusions=settings.EXCLUDES, + file_types=settings.FILE_TYPES, + ), + watched=watched, + ) + server.run() + elif args.command == "search": + search(args.q) diff --git a/src/codesearch/client.py b/src/codesearch/client.py new file mode 100644 index 0000000..e5291e4 --- /dev/null +++ b/src/codesearch/client.py @@ -0,0 +1,83 @@ +import socket +import json +from pathlib import Path + +import curses + + +from .settings import settings +from .colors import highlight + + +def display_handler(stdscr, buffer): + current_y = 0 + stdscr.refresh() + curses.start_color() + y, x = stdscr.getmaxyx() + curses.init_pair(1, curses.COLOR_GREEN, curses.COLOR_BLACK) + pad = curses.newpad(y, x) + while True: + row = 0 + y_offset = 0 + pad.clear() + while row < current_y + y - 1: + l = buffer[current_y + y_offset] + if l["type"] == "path": + pad.addstr(row, 0, l["value"], curses.color_pair(1)) + row += 1 + y_offset += 1 + elif l["type"] == "sep": + row += 1 + y_offset += 1 + else: + pad.addstr(row, 0, str(l["lineno"]), curses.color_pair(1)) + pad.addstr(row, 5, l["value"]) + row += 1 + y_offset += 1 + + if y_offset == y or current_y == y - 1: + break + + pad.refresh(0, 0, 0, 0, y, x) + key = stdscr.getch() + + if key in [81, 113]: + break + elif key == curses.KEY_UP: + current_y = max(0, current_y - 1) + elif key == curses.KEY_DOWN: + current_y = min(len(buffer), current_y + 1) + elif key == curses.KEY_NPAGE: + current_y = min(len(buffer), current_y + y + 1) + elif key == curses.KEY_PPAGE: + current_y = max(0, current_y - y - 1) + + +def search(query): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.connect((settings.SOCKET_HOST, settings.SOCKET_PORT)) + s.sendall(query.encode()) + length = int(s.recv(8).decode()) + results = None + + with open(Path(settings.BUFFER_PATH).expanduser(), "rb") as infile: + results = infile.read().decode() + + results = json.loads(results) + + output = [] + for result in results: + with open(result["key"], "r") as infile: + highlighted_text = infile.read()[ + result["offset_start"] : 
result["offset_end"] + ] + line_number = result["line_start"] + output.append({"value": result["key"], "type": "path"}) + for l in highlighted_text.split("\n"): + output.append({"value": l, "type": "code", "lineno": line_number}) + line_number += 1 + output.append({"type": "sep"}) + + s.close() + + curses.wrapper(display_handler, output) diff --git a/src/codesearch/colors.py b/src/codesearch/colors.py new file mode 100644 index 0000000..98a10d6 --- /dev/null +++ b/src/codesearch/colors.py @@ -0,0 +1,7 @@ +COLORS = {"green": "\033[92m", "yellow": "\033[93m", "red": "\033[91m"} +ENDC = "\033[0m" + + +def highlight(text, color="green"): + color_code = COLORS[color] + return f"{color_code}{text}{ENDC}" diff --git a/src/codesearch/constants.py b/src/codesearch/constants.py new file mode 100644 index 0000000..9a14ce5 --- /dev/null +++ b/src/codesearch/constants.py @@ -0,0 +1,11 @@ +SETTINGS_KEYS = [ + "WATCHED", + "SOCKET_PORT", + "SOCKET_HOST", + "EXCLUDES", + "FILE_TYPES", + "SIGNIFICANCE_THRESHOLD", + "INDEXING_PROCESSES", + "BUFFER_PATH", +] +QUERY_STRING_LENGTH = 1024 diff --git a/src/codesearch/document_models.py b/src/codesearch/document_models.py new file mode 100644 index 0000000..1652a33 --- /dev/null +++ b/src/codesearch/document_models.py @@ -0,0 +1,52 @@ +import attr + + +@attr.s +class Corpus: + _documents = attr.ib(default=attr.Factory(dict)) + _key_to_uid = attr.ib(default=attr.Factory(dict)) + + @property + def document_count(self): + return len(self._documents) + + def add_document(self, key, content): + document_uid = f"document:{self.document_count}" + + self._documents[document_uid] = Document( + uid=document_uid, key=key, content=content + ) + self._key_to_uid[key] = document_uid + + return document_uid + + def get_document(self, uid=None, key=None): + if key: + uid = self._key_to_uid[key] + + return self._documents[uid] + + def collect_unprocessed_documents(self): + return [ + uid + for uid in self._documents + if not self.get_document(uid=uid).is_processed + ] + + def mark_document_as_processed(self, uid): + self._documents[uid].mark_as_processed() + + +@attr.s +class Document: + uid = attr.ib() + key = attr.ib() + content = attr.ib() + _processed = attr.ib(default=False) + + @property + def is_processed(self): + return self._processed + + def mark_as_processed(self): + self._processed = True diff --git a/src/codesearch/indexer.py b/src/codesearch/indexer.py new file mode 100644 index 0000000..c4090fa --- /dev/null +++ b/src/codesearch/indexer.py @@ -0,0 +1,211 @@ +from .base import IndexerBase +from pathlib import Path +from typing import Dict, List +import re +from time import perf_counter +from multiprocessing import Pool +import mmap + +import attr + +from .settings import settings + +from .process_utils import chunkify_content +from .document_models import Corpus +from .trigram_index import TrigramIndex +from .line_index import LineIndex +from .logger import get_logger + +logger = get_logger(__name__) + + +@attr.s +class SearchResult: + key = attr.ib() + offset_start = attr.ib() + offset_end = attr.ib() + line_start = attr.ib() + line_end = attr.ib() + + def to_dict(self): + return { + "key": self.key, + "offset_start": self.offset_start, + "offset_end": self.offset_end, + "line_start": self.line_start, + "line_end": self.line_end, + } + + +@attr.s +class Indexer(IndexerBase): + # Indices + _trigram_index = attr.ib(default=attr.Factory(TrigramIndex)) + _line_index = attr.ib(default=attr.Factory(LineIndex)) + + _exclusions = attr.ib(default=attr.Factory(list)) + 
+    _file_types = attr.ib(default=attr.Factory(list))
+    # Document corpus
+    corpus = attr.ib(default=attr.Factory(Corpus))
+    domain = attr.ib(default=attr.Factory(list))
+
+    def index(self, paths: List[str]):
+        start_time = perf_counter()
+        discovered = []
+        for path in paths:
+            discovered.extend(self._discover(path))
+
+        logger.info(f"Discovered {len(discovered)} files.", prefix="Discovery")
+
+        self._build_corpus(discovered)
+        self._populate_indices(self.corpus.collect_unprocessed_documents())
+        end_time = perf_counter()
+
+        logger.info(
+            f"{self.corpus.document_count} total files indexed in {end_time - start_time} seconds.",
+            prefix="Index status",
+        )
+
+    def query(self, query: str):
+        start_time = perf_counter()
+        leads = self._trigram_index.query(query)
+        logger.info(
+            f"Narrowed down to {len(leads)} files via trigram search", prefix="Query"
+        )
+        confirmed = []
+        uniques = 0
+        for lead in leads:
+            uid, score = lead
+            lead_path = self.corpus.get_document(uid=uid).key
+            lead_content = ""
+            try:
+                with open(lead_path, "r") as infile:
+                    m = mmap.mmap(infile.fileno(), 0, prot=mmap.PROT_READ)
+                    lead_content = m.read().decode()
+            except Exception as e:
+                logger.warning(e)
+                logger.warning(f"No content in {lead_path}", prefix="Query")
+
+            results = re.finditer(query, lead_content)
+            hits_in_lead = []
+            for hit in results:
+                start_line, end_line = self._find_line_range(
+                    lead_path, hit.start(), hit.end()
+                )
+                start_offset = self._line_index.query(lead_path)[start_line][0]
+                end_offset = self._line_index.query(lead_path)[end_line][1]
+
+                hits_in_lead.append(
+                    SearchResult(
+                        key=lead_path,
+                        offset_start=start_offset,
+                        offset_end=end_offset,
+                        line_start=start_line,
+                        line_end=end_line,
+                    )
+                )
+
+            if hits_in_lead:
+                confirmed.extend(hits_in_lead)
+                uniques += 1
+        end_time = perf_counter()
+        logger.info(
+            f"{len(confirmed)} hits in {uniques} files ({end_time - start_time} seconds elapsed).",
+            prefix="Query",
+        )
+        return [r.to_dict() for r in confirmed]
+
+    def _discover(self, path_root: str) -> List[str]:
+        collected = []
+        current = Path(path_root)
+
+        # Avoid any excluded paths
+        if any([current.match(x) for x in self._exclusions]):
+            logger.info(f"{path_root} excluded.", prefix="Discovery")
+            return []
+
+        if current.is_dir():
+            for child_path in current.iterdir():
+                collected.extend(self._discover(str(child_path)))
+
+            return collected
+
+        if current.suffix not in self._file_types:
+            return []
+
+        logger.info(f"Collected {path_root}", prefix="Discovery")
+        return [path_root]
+
+    def _build_corpus(self, discovered: List[str]):
+        total = len(discovered)
+        current = 0
+        for discovered_file in discovered:
+            self.corpus.add_document(key=discovered_file, content="")
+            current += 1
+            logger.info(
+                f"({current}/{total}) Registered {discovered_file} in corpus",
+                prefix="Corpus building",
+            )
+
+    def _populate_indices(self, uids):
+        processes = settings.INDEXING_PROCESSES
+        pool = Pool(processes=processes)
+        chunks = chunkify_content(uids, processes)
+        processed_chunks = pool.map(self._bulk_process, chunks)
+
+        for result in processed_chunks:
+            for uid in result[0]:
+                self._trigram_index.index(
+                    uid.replace("document:", ""), None, None, result[0][uid]
+                )
+            self._line_index._lines.update(result[1])
+
+    # TODO: Tidy up, rethink w.r.t. multiprocessing.
+    def _bulk_process(self, uids: List[str]):
+        trigrams = {}
+        total = len(uids)
+        current = 0
+        for uid in uids:
+            document = self.corpus.get_document(uid=uid)
+            path = document.key
+            try:
+                with open(path, "r") as document_file:
+                    mapped_file = mmap.mmap(
+                        document_file.fileno(), 0, prot=mmap.PROT_READ
+                    )
+                    content = mapped_file.read().decode()
+                    trigrams[uid] = TrigramIndex.trigramize(content)
+                    self._line_index.index(path, content)
+                    current += 1
+                    logger.info(
+                        f"({current}/{total}) Processed {path}", prefix="Indexing"
+                    )
+            except Exception as e:
+                logger.info(e)
+                current += 1
+                logger.warning(
+                    f"({current}/{total}) Could not read {path}, skipping.",
+                    prefix="Indexing",
+                )
+
+        return (trigrams, self._line_index._lines)
+
+    def _find_closest_line(self, path, index):
+        content = self._line_index.query(path)
+
+        for l in content:
+            if content[l][0] <= index <= content[l][1]:
+                return l
+        # TODO: This should not be reachable.
+        return 0
+
+    def _find_line_range(self, key, start, end, padding=5):
+        start_line = self._find_closest_line(key, start)
+        end_line = self._find_closest_line(key, end)
+
+        start_line_range = max(0, start_line - padding)
+        end_line_range = min(len(self._line_index.query(key)) - 1, end_line + padding)
+
+        return (start_line_range, end_line_range)
diff --git a/src/codesearch/line_index.py b/src/codesearch/line_index.py
new file mode 100644
index 0000000..e672d6f
--- /dev/null
+++ b/src/codesearch/line_index.py
@@ -0,0 +1,22 @@
+from .base import IndexBase
+import attr
+
+from .logger import get_logger
+
+logger = get_logger(__name__)
+
+
+@attr.s
+class LineIndex(IndexBase):
+    _lines = attr.ib(default=attr.Factory(dict))
+
+    def index(self, key: str, content: str):
+        self._lines[key] = {}
+        current, count = 0, 0
+        for line in content.split("\n"):
+            self._lines[key][count] = (current, current + len(line))
+            current += len(line) + 1
+            count += 1
+
+    def query(self, key: str):
+        return self._lines[key]
diff --git a/src/codesearch/logger.py b/src/codesearch/logger.py
new file mode 100644
index 0000000..b24d5aa
--- /dev/null
+++ b/src/codesearch/logger.py
@@ -0,0 +1,35 @@
+import logging
+import sys
+import attr
+
+from .colors import highlight
+
+
+@attr.s
+class Logger:
+    logger = attr.ib()
+
+    def info(self, message, prefix=None):
+        prefix_str = ""
+        if prefix:
+            prefix_str = highlight(f"[{prefix}]", "green")
+
+        self.logger.info(f"{prefix_str} {message}")
+
+    def warning(self, message, prefix=None):
+        prefix_str = ""
+        if prefix:
+            prefix_str = highlight(f"[{prefix}]", "yellow")
+
+        self.logger.warning(f"{prefix_str} {message}")
+
+
+def get_logger(name):
+    logger = logging.getLogger(name)
+    logger.setLevel(logging.INFO)
+    handler = logging.StreamHandler(sys.stdout)
+    logger.addHandler(handler)
+
+    logger_obj = Logger(logger=logger)
+
+    return logger_obj
diff --git a/src/codesearch/prefix_tree.py b/src/codesearch/prefix_tree.py
new file mode 100644
index 0000000..3eb2ea0
--- /dev/null
+++ b/src/codesearch/prefix_tree.py
@@ -0,0 +1,66 @@
+import json
+
+import attr
+
+
+@attr.s
+class PrefixTree:
+    root = attr.ib()
+
+    @staticmethod
+    def initialize():
+        root = PrefixTreeNode(value=None)
+        return PrefixTree(root=root)
+
+    def insert(self, value, key, current=None):
+        if current is None:
+            current = self.root
+
+        if not value:
+            current.mappings.append(key)
+            return
+        top = value[0]
+        rest = value[1:]
+
+        next_child = current.children.get(top)
+
+        if next_child:
+            self.insert(rest, key, next_child)
+        else:
+            new_node = PrefixTreeNode(value=top)
+            current.children[top] = new_node
+            self.insert(rest, key, new_node)
+
+    def get(self, value, current=None):
+        if not current:
+            current = self.root
+        if not value:
+            return current.mappings
+
+        top = value[0]
+        rest = value[1:]
+
+        next_child = current.children.get(top)
+
+        if next_child:
+            return self.get(rest, next_child)
+
+    def to_dict(self):
+        return self.root.to_dict()
+
+    def to_json(self):
+        return json.dumps(self.to_dict())
+
+
+@attr.s
+class PrefixTreeNode:
+    value = attr.ib()
+    mappings = attr.ib(default=attr.Factory(list))
+    children = attr.ib(default=attr.Factory(dict))
+
+    def to_dict(self):
+        return {
+            "value": self.value,
+            "mappings": self.mappings,
+            "children": [child.to_dict() for child in self.children.values()],
+        }
diff --git a/src/codesearch/process_utils.py b/src/codesearch/process_utils.py
new file mode 100644
index 0000000..e5d6613
--- /dev/null
+++ b/src/codesearch/process_utils.py
@@ -0,0 +1,15 @@
+def chunkify_content(content, chunk_count, chunk_size=None):
+    if chunk_size is None:
+        chunk_size = int(len(content) / chunk_count)
+    chunks = []
+    last_boundary = 0
+
+    for i in range(chunk_count):
+        if i == chunk_count - 1:
+            chunks.append(content[last_boundary:])
+        else:
+            chunks.append(content[last_boundary : last_boundary + chunk_size])
+
+        last_boundary += chunk_size
+
+    return chunks
diff --git a/src/codesearch/server.py b/src/codesearch/server.py
new file mode 100644
index 0000000..bebf1b7
--- /dev/null
+++ b/src/codesearch/server.py
@@ -0,0 +1,81 @@
+import json
+import socket
+import pyinotify
+import attr
+from codesearch.watcher import WatchHandler
+from codesearch.indexer import Indexer
+from codesearch.constants import QUERY_STRING_LENGTH
+from pathlib import Path
+from codesearch.settings import settings
+
+from codesearch.logger import get_logger
+
+logger = get_logger(__name__)
+
+
+@attr.s
+class Server:
+    indexer = attr.ib()
+    watched = attr.ib()
+    _notifier = attr.ib(default=None)
+    _socket = attr.ib(default=None)
+
+    def _handle_socket(self, *, socket):
+        socket.bind((settings.SOCKET_HOST, settings.SOCKET_PORT))
+        socket.listen()
+
+        logger.info(
+            f"Listening on {settings.SOCKET_HOST}:{settings.SOCKET_PORT}",
+            prefix="Server",
+        )
+
+        while True:
+            conn, _ = socket.accept()
+            query_string = conn.recv(QUERY_STRING_LENGTH).decode()
+            logger.info(f"Query string: {query_string}", prefix="Query")
+            if query_string:
+                try:
+                    query_results = self.indexer.query(query_string)
+                    response = json.dumps(query_results).encode()
+                    response_length = str(len(response.decode()))
+                    with open(Path(settings.BUFFER_PATH).expanduser(), "wb") as outfile:
+                        outfile.write(response)
+                    conn.sendall(response_length.encode())
+                except KeyboardInterrupt:
+                    raise
+                except Exception as e:
+                    logger.warning(e)
+                    pass
+
+    def _start_socket(self):
+        try:
+            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as socket_obj:
+                self._socket = socket_obj
+                self._handle_socket(socket=socket_obj)
+        except Exception as e:
+            logger.warning(e)
+            raise e
+
+    def _start_watch(self):
+        watch_manager = pyinotify.WatchManager()
+
+        for path in self.watched:
+            logger.info(f"Watching {path}", prefix="Server")
+            watch_manager.add_watch(path, pyinotify.ALL_EVENTS, rec=True)
+
+        event_handler = WatchHandler(indexer=self.indexer)
+        notifier = pyinotify.ThreadedNotifier(watch_manager, event_handler)
+        notifier.start()
+        self._notifier = notifier
+
+    def run(self):
+        collected = {}
+
+        self.indexer.index(self.watched)
+
+        try:
+            self._start_watch()
+            self._start_socket()
+        except:
+            self._socket.close()
+            self._notifier.stop()
diff --git a/src/codesearch/settings.py b/src/codesearch/settings.py
new file mode 100644
index 0000000..ea5692a
--- /dev/null
+++ b/src/codesearch/settings.py
@@ -0,0 +1,45 @@
+import json
+
+from pathlib import Path
+import attr
+
+from .constants import SETTINGS_KEYS
+
+SETTINGS_PATH = "~/.codesearchrc"
+
+default_settings = {
+    "SOCKET_HOST": "127.0.0.1",
+    "SOCKET_PORT": 65126,
+    "EXCLUDES": [],
+    "FILE_TYPES": [],
+    "SIGNIFICANCE_THRESHOLD": 0,
+    "WATCHED": [],
+    "INDEXING_PROCESSES": 4,
+    "BUFFER_PATH": "~/.codesearchbuffer",
+}
+
+
+@attr.s
+class Settings:
+    settings = attr.ib(default=attr.Factory(dict))
+
+    def from_file(self, path: str):
+        settings_path = Path(path).expanduser()
+
+        if not settings_path.exists():
+            self.settings = default_settings
+            return
+
+        with open(settings_path, "r") as settings_file:
+            self.settings = json.load(settings_file)
+
+    def __getattr__(self, key):
+        if key not in SETTINGS_KEYS:
+            raise KeyError(f"{key} not a valid settings property")
+
+        return self.settings[key]
+
+
+settings = Settings()
+
+settings.from_file(Path(SETTINGS_PATH).expanduser())
diff --git a/src/codesearch/test_indexer.py b/src/codesearch/test_indexer.py
new file mode 100644
index 0000000..e0d3960
--- /dev/null
+++ b/src/codesearch/test_indexer.py
@@ -0,0 +1,49 @@
+import pytest
+
+from .indexer import Indexer
+
+
+@pytest.fixture()
+def indexer():
+    return Indexer()
+
+
+def test_indexer_builds_trigram_set_for_given_document(indexer):
+    mock_document = "now that's a doc"
+    mock_path = "/home/documents/cool_doc"
+
+    indexer.index(path=mock_path, content=mock_document)
+
+    expected_trigrams = [
+        "now",
+        "ow ",
+        "w t",
+        " th",
+        "tha",
+        "hat",
+        "at'",
+        "t's",
+        "'s ",
+        "s a",
+        " a ",
+        "a d",
+        " do",
+        "doc",
+    ]
+
+    assert indexer.trigrams == {mock_path: set(expected_trigrams)}
+
+
+def test_indexer_preserves_previous_trigram_sets_on_index(indexer):
+    mock_document_1 = "wow"
+    mock_document_2 = "woa"
+    mock_path_1 = "/home"
+    mock_path_2 = "/somewhere_else"
+
+    indexer.index(path=mock_path_1, content=mock_document_1)
+
+    assert indexer.trigrams == {mock_path_1: set(["wow"])}
+
+    indexer.index(path=mock_path_2, content=mock_document_2)
+
+    assert indexer.trigrams == {mock_path_1: set(["wow"]), mock_path_2: set(["woa"])}
diff --git a/src/codesearch/test_prefix_tree.py b/src/codesearch/test_prefix_tree.py
new file mode 100644
index 0000000..2d52ac6
--- /dev/null
+++ b/src/codesearch/test_prefix_tree.py
@@ -0,0 +1,67 @@
+import pytest
+
+from .prefix_tree import PrefixTree
+
+
+@pytest.fixture
+def prefix_tree():
+    return PrefixTree.initialize()
+
+
+def test_base_tree_has_a_root_node(prefix_tree, snapshot):
+    assert prefix_tree.to_dict() == snapshot
+
+
+def test_insert_single_string(prefix_tree, snapshot):
+    mock_value = "abc"
+    mock_key = "key_1"
+    prefix_tree.insert(value=mock_value, key=mock_key)
+    assert prefix_tree.to_dict() == snapshot
+    assert prefix_tree.get(value=mock_value) == [mock_key]
+
+
+def test_insert_single_character_(prefix_tree, snapshot):
+    mock_value = "a"
+    mock_key = "key_1"
+    prefix_tree.insert(value=mock_value, key=mock_key)
+    assert prefix_tree.to_dict() == snapshot
+    assert prefix_tree.get(value=mock_value) == [mock_key]
+
+
+def test_insert_overlapping_strings(prefix_tree, snapshot):
+    mock_value_1 = "abcd"
+    mock_key_1 = "key_1"
+    mock_value_2 = "abce"
+    mock_key_2 = "key_2"
+    prefix_tree.insert(value=mock_value_1, key=mock_key_1)
+    prefix_tree.insert(value=mock_value_2, key=mock_key_2)
+    assert prefix_tree.to_dict() == snapshot
+    assert prefix_tree.get(value=mock_value_1) == [mock_key_1]
+    assert prefix_tree.get(value=mock_value_2) == [mock_key_2]
+
+
+def test_insert_multiple_keys_same_string(prefix_tree, snapshot):
+    mock_value = "abcd"
+    mock_key_1 = "key_1"
+    mock_key_2 = "key_2"
+    prefix_tree.insert(value=mock_value, key=mock_key_1)
+    prefix_tree.insert(value=mock_value, key=mock_key_2)
+    assert prefix_tree.to_dict() == snapshot
+    assert prefix_tree.get(value=mock_value) == [mock_key_1, mock_key_2]
+
+
+def test_insert_strings_subsets_of_each_other(prefix_tree, snapshot):
+    mock_value_1 = "abcd"
+    mock_key_1 = "key_1"
+    mock_value_2 = "abc"
+    mock_key_2 = "key_2"
+    prefix_tree.insert(value=mock_value_1, key=mock_key_1)
+    prefix_tree.insert(value=mock_value_2, key=mock_key_2)
+    assert prefix_tree.to_dict() == snapshot
+    assert prefix_tree.get(value=mock_value_1) == [mock_key_1]
+    assert prefix_tree.get(value=mock_value_2) == [mock_key_2]
+
+
+def test_serializes_to_json(prefix_tree, snapshot):
+    prefix_tree.insert(value="abcd", key="key_1")
+    assert prefix_tree.to_json() == snapshot
diff --git a/src/codesearch/trigram_index.py b/src/codesearch/trigram_index.py
new file mode 100644
index 0000000..a147544
--- /dev/null
+++ b/src/codesearch/trigram_index.py
@@ -0,0 +1,47 @@
+from typing import List, Optional, Set
+
+import attr
+from .settings import settings
+from .base import IndexBase
+from .prefix_tree import PrefixTree
+
+
+@attr.s
+class TrigramIndex(IndexBase):
+    _threshold = attr.ib(default=settings.SIGNIFICANCE_THRESHOLD)
+    _tree = attr.ib(default=attr.Factory(PrefixTree.initialize))
+
+    def index(self, uid, key: str, content: str, trigrams):
+        if content:
+            trigrams = TrigramIndex.trigramize(content)
+
+        for trigram in trigrams:
+            self._tree.insert(trigram, uid)
+
+    def query(self, query: str, haystack: Optional[List[str]] = None) -> List[str]:
+        query_trigrams = TrigramIndex.trigramize(query)
+        results = {}
+
+        for trigram in query_trigrams:
+            result_set = self._tree.get(trigram)
+            if result_set:
+                results[trigram] = result_set
+
+        matches = {}
+
+        for result in results:
+            for doc in results[result]:
+                matches[doc] = matches.get(doc, 0) + 1
+
+        significant_results = []
+        for uid, occurrences in matches.items():
+            score = occurrences / len(query_trigrams)
+            if score >= self._threshold:
+                significant_results.append((f"document:{uid}", score))
+
+        significant_results.sort(reverse=True, key=lambda x: x[1])
+        return significant_results
+
+    @staticmethod
+    def trigramize(content: str) -> Set[str]:
+        return {content[pos : pos + 3].lower() for pos in range(len(content) - 2)}
diff --git a/src/codesearch/watcher.py b/src/codesearch/watcher.py
new file mode 100644
index 0000000..92c25db
--- /dev/null
+++ b/src/codesearch/watcher.py
@@ -0,0 +1,14 @@
+import pyinotify
+import attr
+
+from .logger import get_logger
+
+logger = get_logger(__name__)
+
+
+@attr.s
+class WatchHandler(pyinotify.ProcessEvent):
+    indexer = attr.ib()
+
+    def process_IN_MODIFY(self, event):
+        self.indexer.index([event.pathname])
diff --git a/src/setup.py b/src/setup.py
new file mode 100644
index 0000000..6d1bf11
--- /dev/null
+++ b/src/setup.py
@@ -0,0 +1,9 @@
+from setuptools import setup
+
+setup(
+    name="codesearch",
+    version="0.1",
+    packages=["codesearch"],
+    install_requires=["pyinotify", "attrs"],
+    entry_points={"console_scripts": ["codesearch=codesearch.cli:main"]},
+)
diff --git a/tasks.py b/tasks.py
new file mode 100644
index 0000000..00ca26b
--- /dev/null
+++ b/tasks.py
@@ -0,0 +1,6 @@
+from invoke import task
+
+
+@task
+def lint(ctx):
ctx.run("black *.py src")