wip: functional version

This commit is contained in:
Marc Cataford 2020-09-27 00:11:08 -04:00
parent 73d90beb9a
commit 59795a5dec
15 changed files with 411 additions and 0 deletions

1
.gitignore vendored
View file

@ -2,6 +2,7 @@
__pycache__/ __pycache__/
*.py[cod] *.py[cod]
*$py.class *$py.class
*.sw[a-z]
# C extensions # C extensions
*.so *.so

21
requirements.txt Normal file
View file

@ -0,0 +1,21 @@
appdirs==1.4.4
attr==0.3.1
attrs==20.2.0
black==20.8b1
click==7.1.2
iniconfig==1.0.1
invoke==1.4.1
more-itertools==8.5.0
mypy-extensions==0.4.3
packaging==20.4
pathspec==0.8.0
pluggy==0.13.1
py==1.9.0
pyinotify==0.9.6
pyparsing==2.4.7
pytest==6.0.2
regex==2020.7.14
six==1.15.0
toml==0.10.1
typed-ast==1.4.1
typing-extensions==3.7.4.3

15
script/bootstrap Normal file
View file

@ -0,0 +1,15 @@
VENV=codesearch.venv
#################################################################
# Bootstrapping sets up the Python 3.8 venv that allows the use #
# of the invoke commands. #
#################################################################
{
pyenv virtualenv-delete -f $VENV
pyenv virtualenv $VENV &&
pyenv activate $VENV &&
python -m pip install -U pip &&
pip install -r requirements.txt &&
echo "✨ Good to go! ✨"
}

11
src/base.py Normal file
View file

@ -0,0 +1,11 @@
from abc import ABC, abstractmethod
class IndexerBase(ABC):
@abstractmethod
def discover(self, path: str):
pass
@abstractmethod
def index(self, path: str):
pass

22
src/cli.py Normal file
View file

@ -0,0 +1,22 @@
import argparse
from server import Server
from indexer import Indexer
from client import search
import settings
parser = argparse.ArgumentParser()
parser.add_argument("command")
parser.add_argument("--q", required=False)
args = parser.parse_args()
if args.command == "start":
server = Server(
indexer=Indexer(trigram_threshold=settings.SIGNIFICANCE_THRESHOLD),
watched=[".."],
)
server.run()
elif args.command == "search":
search(args.q)

28
src/client.py Normal file
View file

@ -0,0 +1,28 @@
import socket
import json
import settings
from colors import highlight
def search(query):
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.connect((settings.SOCKET_HOST, settings.SOCKET_PORT))
s.sendall(query.encode())
length = int(s.recv(4).decode())
results = json.loads(s.recv(length).decode())
for result in results:
with open(result[0], "r") as infile:
highlighted_text = infile.read()[result[1] : result[2]].replace(
query, highlight(query)
)
line_number = result[3]
print(highlight(result[0]))
for l in highlighted_text.split("\n"):
print(f"{highlight(line_number)} {l}")
line_number += 1
print("\n\n")
s.close()

4
src/colors.py Normal file
View file

@ -0,0 +1,4 @@
def highlight(text):
GREEN = "\033[92m"
ENDC = "\033[0m"
return f"{GREEN}{text}{ENDC}"

1
src/constants.py Normal file
View file

@ -0,0 +1 @@
QUERY_STRING_LENGTH = 1024

135
src/indexer.py Normal file
View file

@ -0,0 +1,135 @@
from base import IndexerBase
from pathlib import Path
from typing import Dict
import re
import attr
import settings
from logger import get_logger
logger = get_logger(__name__)
@attr.s
class Indexer(IndexerBase):
_trigrams = attr.ib(default=attr.Factory(dict))
_collected = attr.ib(default=attr.Factory(dict))
_lines = attr.ib(default=attr.Factory(dict))
trigram_threshold = attr.ib(default=0)
@property
def trigrams(self):
return dict(self._trigrams)
def discover(self, path_root: str) -> Dict[str, str]:
logger.info(f"Collecting {path_root}")
collected = {}
if any([x in path_root for x in settings.excludes]):
return {}
current = Path(path_root)
if current.is_dir():
for child_path in current.iterdir():
collected.update(self.discover(str(child_path)))
self._collected.update(collected)
return dict(self._collected)
if current.suffix not in settings.types:
return {}
try:
with open(current, "r") as infile:
self._collected[str(current.resolve())] = infile.read()
except:
pass
return dict(self._collected)
def index(self, path: str, content: str):
p = Path(path)
self._trigrams[path] = set()
for idx in range(len(content) - 2):
self._trigrams[path].add(content[idx : idx + 3])
self._lines[path] = {}
content = self._collected[path]
current, count = 0, 0
for line in self._collected[path].split("\n"):
self._lines[path][count] = (current, current + len(line))
current += len(line)
count += 1
def query(self, query: str):
trigram_results = self.search_trigrams(query)
confirmed = self.search_content(query, trigram_results)
return confirmed
def find_closest_line(self, path, index, offset=0):
content = self._lines[path]
for l in content:
if content[l][0] <= index <= content[l][1]:
return l
logger.error(f"{path} {index}")
logger.error(content)
return 0
def search_content(self, query: str, leads):
confirmed = []
uniques = 0
for lead in leads:
lead_content = self._collected[lead[1]]
results = re.finditer(query, lead_content)
hits_in_lead = []
for hit in results:
start_line = self.find_closest_line(lead[1], hit.start())
end_line = self.find_closest_line(lead[1], hit.end())
start_line_range = max(0, start_line - 5)
end_line_range = min(len(self._lines[lead[1]]) - 1, end_line + 5)
hits_in_lead.append(
(
lead[1],
self._lines[lead[1]][start_line_range][0],
self._lines[lead[1]][end_line_range][1],
start_line_range,
end_line_range,
)
)
if hits_in_lead:
confirmed.extend(hits_in_lead)
uniques += 1
logger.info(f"{len(confirmed)} hits in {uniques} files.")
return confirmed
def search_trigrams(self, query: str):
query_trigrams = [query[idx : idx + 3] for idx in range(len(query) - 2)]
results = []
for item in self.trigrams:
shared = self.trigrams[item].intersection(query_trigrams)
ratio = len(shared) / len(query_trigrams)
if ratio < self.trigram_threshold:
continue
results.append((ratio, item, list(shared)))
results.sort(reverse=True, key=lambda x: x[0])
logger.info(f"Narrowed down to {len(results)} files via trigram search")
return results

11
src/logger.py Normal file
View file

@ -0,0 +1,11 @@
import logging
import sys
def get_logger(name):
logger = logging.getLogger(name)
logger.setLevel(logging.INFO)
handler = logging.StreamHandler(sys.stdout)
logger.addHandler(handler)
return logger

84
src/server.py Normal file
View file

@ -0,0 +1,84 @@
import json
import socket
import pyinotify
import attr
from watcher import WatchHandler
from indexer import Indexer
from constants import QUERY_STRING_LENGTH
import settings
from logger import get_logger
logger = get_logger(__name__)
@attr.s
class Server:
indexer = attr.ib()
watched = attr.ib()
_notifier = attr.ib(default=None)
_socket = attr.ib(default=None)
def _handle_socket(self, *, socket):
socket.bind((settings.SOCKET_HOST, settings.SOCKET_PORT))
socket.listen()
logger.info(f"Listening on ${settings.SOCKET_HOST}:${settings.SOCKET_PORT}")
while True:
conn, _ = socket.accept()
query_string = conn.recv(QUERY_STRING_LENGTH)
logger.info(f"Query: {query_string}")
if query_string:
try:
query_results = self.indexer.query(query_string.decode())
response = json.dumps(query_results).encode()
response_length = str(len(response))
conn.sendall(response_length.encode())
conn.sendall(response)
except KeyboardInterrupt:
raise e
except Exception as e:
logger.exception(e)
def _start_socket(self):
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as socket_obj:
self._socket = socket_obj
self._handle_socket(socket=socket_obj)
except Exception as e:
logger.exception(e)
raise e
def _start_watch(self):
watch_manager = pyinotify.WatchManager()
for path in self.watched:
logger.info(f"Watching ${path}")
watch_manager.add_watch(path, pyinotify.ALL_EVENTS, rec=True)
event_handler = WatchHandler(indexer=self.indexer)
notifier = pyinotify.ThreadedNotifier(watch_manager, event_handler)
notifier.start()
self._notifier = notifier
def run(self):
collected = {}
for watched_path in self.watched:
logger.info(f"Collecting files from ${watched_path}")
collected.update(self.indexer.discover(watched_path))
for c in collected:
logger.info(f"Indexing ${c}")
self.indexer.index(c, collected[c])
try:
self._start_watch()
self._start_socket()
except:
self._socket.close()
self._notifier.stop()

6
src/settings.py Normal file
View file

@ -0,0 +1,6 @@
SOCKET_HOST = "127.0.0.1"
SOCKET_PORT = 65126
excludes = ["node_modules", ".git", ".venv"]
types = [".py", ".js"]
SIGNIFICANCE_THRESHOLD = 0.7

49
src/test_indexer.py Normal file
View file

@ -0,0 +1,49 @@
import pytest
from indexer import Indexer
@pytest.fixture()
def indexer():
return Indexer()
def test_indexer_builds_trigram_set_for_given_document(indexer):
mock_document = "now that's a doc"
mock_path = "/home/documents/cool_doc"
indexer.index(path=mock_path, content=mock_document)
expected_trigrams = [
"now",
"ow ",
"w t",
" th",
"tha",
"hat",
"at'",
"t's",
"'s ",
"s a",
" a ",
"a d",
" do",
"doc",
]
assert indexer.trigrams == {mock_path: set(expected_trigrams)}
def test_indexer_preserves_previous_trigram_sets_on_index(indexer):
mock_document_1 = "wow"
mock_document_2 = "woa"
mock_path_1 = "/home"
mock_path_2 = "/somewhere_else"
indexer.index(path=mock_path_1, content=mock_document_1)
assert indexer.trigrams == {mock_path_1: set(["wow"])}
indexer.index(path=mock_path_2, content=mock_document_2)
assert indexer.trigrams == {mock_path_1: set(["wow"]), mock_path_2: set(["woa"])}

17
src/watcher.py Normal file
View file

@ -0,0 +1,17 @@
import pyinotify
import attr
from logger import get_logger
logger = get_logger(__name__)
@attr.s
class WatchHandler(pyinotify.ProcessEvent):
indexer = attr.ib()
def process_IN_MODIFY(self, event):
c = self.indexer.discover(event.pathname)
if c.get(event.pathname):
logger.info(f"Reindexed {event.pathname}")
self.indexer.index(event.pathname, c[event.pathname])

6
tasks.py Normal file
View file

@ -0,0 +1,6 @@
from invoke import task
@task
def lint(ctx):
ctx.run("black *.py src")