wip: functional version
This commit is contained in:
parent
73d90beb9a
commit
59795a5dec
15 changed files with 411 additions and 0 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -2,6 +2,7 @@
|
||||||
__pycache__/
|
__pycache__/
|
||||||
*.py[cod]
|
*.py[cod]
|
||||||
*$py.class
|
*$py.class
|
||||||
|
*.sw[a-z]
|
||||||
|
|
||||||
# C extensions
|
# C extensions
|
||||||
*.so
|
*.so
|
||||||
|
|
21
requirements.txt
Normal file
21
requirements.txt
Normal file
|
@ -0,0 +1,21 @@
|
||||||
|
appdirs==1.4.4
|
||||||
|
# attr==0.3.1  (NOTE(review): likely unintended — the PyPI "attr" package is unrelated; "attrs" below is what provides `import attr`)
|
||||||
|
attrs==20.2.0
|
||||||
|
black==20.8b1
|
||||||
|
click==7.1.2
|
||||||
|
iniconfig==1.0.1
|
||||||
|
invoke==1.4.1
|
||||||
|
more-itertools==8.5.0
|
||||||
|
mypy-extensions==0.4.3
|
||||||
|
packaging==20.4
|
||||||
|
pathspec==0.8.0
|
||||||
|
pluggy==0.13.1
|
||||||
|
py==1.9.0
|
||||||
|
pyinotify==0.9.6
|
||||||
|
pyparsing==2.4.7
|
||||||
|
pytest==6.0.2
|
||||||
|
regex==2020.7.14
|
||||||
|
six==1.15.0
|
||||||
|
toml==0.10.1
|
||||||
|
typed-ast==1.4.1
|
||||||
|
typing-extensions==3.7.4.3
|
15
script/bootstrap
Normal file
15
script/bootstrap
Normal file
|
@ -0,0 +1,15 @@
|
||||||
|
#!/usr/bin/env bash
# Bug fixes: added a shebang so the script runs under a known shell,
# and quoted $VENV expansions.

VENV=codesearch.venv

#################################################################
# Bootstrapping sets up the Python 3.8 venv that allows the use #
# of the invoke commands.                                       #
#################################################################

{
# Deleting may legitimately fail on a first run (no venv yet), so it is
# deliberately NOT chained with && to the commands below.
pyenv virtualenv-delete -f "$VENV"
pyenv virtualenv "$VENV" &&
pyenv activate "$VENV" &&
python -m pip install -U pip &&
pip install -r requirements.txt &&
echo "✨ Good to go! ✨"
}
|
11
src/base.py
Normal file
11
src/base.py
Normal file
|
@ -0,0 +1,11 @@
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
|
||||||
|
class IndexerBase(ABC):
    """Abstract interface for indexers.

    A concrete indexer collects raw file contents (``discover``) and
    builds a searchable index from them (``index``).
    """

    @abstractmethod
    def discover(self, path: str):
        """Collect indexable file contents under *path*."""

    @abstractmethod
    def index(self, path: str, content: str):
        """Build the index entry for *path* from *content*.

        Bug fix: the abstract signature previously omitted *content*,
        which the concrete ``Indexer.index`` implementation requires.
        """
|
22
src/cli.py
Normal file
22
src/cli.py
Normal file
|
@ -0,0 +1,22 @@
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
from server import Server
|
||||||
|
from indexer import Indexer
|
||||||
|
from client import search
|
||||||
|
import settings
|
||||||
|
|
||||||
|
# Minimal command-line front end:
#   `start`          — boot the indexing/search server
#   `search --q TERM` — query a running server over the local socket
parser = argparse.ArgumentParser()

parser.add_argument("command")  # "start" or "search" (see branches below)
parser.add_argument("--q", required=False)  # query string, used by "search"

args = parser.parse_args()

if args.command == "start":
    server = Server(
        indexer=Indexer(trigram_threshold=settings.SIGNIFICANCE_THRESHOLD),
        # ".." — presumably the repository root when run from src/; confirm.
        watched=[".."],
    )
    server.run()
elif args.command == "search":
    search(args.q)
|
28
src/client.py
Normal file
28
src/client.py
Normal file
|
@ -0,0 +1,28 @@
|
||||||
|
import socket
|
||||||
|
import json
|
||||||
|
|
||||||
|
import settings
|
||||||
|
|
||||||
|
from colors import highlight
|
||||||
|
|
||||||
|
|
||||||
|
def search(query):
    """Send *query* to the search server and pretty-print the hits.

    The wire protocol (mirrors ``Server._handle_socket``): the server
    first sends the payload length, then a JSON list of hits of the form
    ``(path, start offset, end offset, start line, end line)``.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.connect((settings.SOCKET_HOST, settings.SOCKET_PORT))
        s.sendall(query.encode())

        length = int(s.recv(4).decode())

        # Bug fix: recv() may return fewer bytes than requested, so a
        # single recv(length) truncated large result sets. Loop until
        # the whole payload has arrived.
        chunks = []
        received = 0
        while received < length:
            chunk = s.recv(length - received)
            if not chunk:  # peer closed early; parse what we have
                break
            chunks.append(chunk)
            received += len(chunk)
        results = json.loads(b"".join(chunks).decode())

        for result in results:
            with open(result[0], "r") as infile:
                highlighted_text = infile.read()[result[1] : result[2]].replace(
                    query, highlight(query)
                )
            line_number = result[3]
            print(highlight(result[0]))
            for l in highlighted_text.split("\n"):
                print(f"{highlight(line_number)} {l}")
                line_number += 1
            print("\n\n")
        # The redundant s.close() was removed: the `with` block already
        # closes the socket.
|
4
src/colors.py
Normal file
4
src/colors.py
Normal file
|
@ -0,0 +1,4 @@
|
||||||
|
def highlight(text):
    """Return *text* wrapped in ANSI escapes so terminals render it green."""
    green_on = "\033[92m"
    reset = "\033[0m"
    return f"{green_on}{text}{reset}"
|
1
src/constants.py
Normal file
1
src/constants.py
Normal file
|
@ -0,0 +1 @@
|
||||||
|
# Maximum number of bytes the server reads for one client query
# (used as conn.recv(QUERY_STRING_LENGTH) in Server._handle_socket).
QUERY_STRING_LENGTH = 1024
|
135
src/indexer.py
Normal file
135
src/indexer.py
Normal file
|
@ -0,0 +1,135 @@
|
||||||
|
from base import IndexerBase
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict
|
||||||
|
import re
|
||||||
|
|
||||||
|
import attr
|
||||||
|
|
||||||
|
import settings
|
||||||
|
|
||||||
|
|
||||||
|
from logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@attr.s
class Indexer(IndexerBase):
    """Trigram-based code indexer.

    ``discover`` walks a directory tree and caches raw file contents;
    ``index`` derives a trigram set and a per-line offset table from a
    file's content; ``query`` narrows candidate files by trigram overlap
    and then confirms hits with a regex scan.
    """

    # path -> set of all 3-character substrings of the file's content
    _trigrams = attr.ib(default=attr.Factory(dict))
    # path -> raw file content (filled by discover() and index())
    _collected = attr.ib(default=attr.Factory(dict))
    # path -> {line number: (start offset, end offset)} into the content
    _lines = attr.ib(default=attr.Factory(dict))
    # minimum shared-trigram ratio for a file to remain a candidate
    trigram_threshold = attr.ib(default=0)

    @property
    def trigrams(self):
        """A copy of the path -> trigram-set mapping."""
        return dict(self._trigrams)

    def discover(self, path_root: str) -> Dict[str, str]:
        """Recursively collect the contents of indexable files.

        Paths containing any fragment from ``settings.excludes`` are
        skipped, as are suffixes not listed in ``settings.types``.
        Returns a copy of everything collected so far.
        """
        logger.info(f"Collecting {path_root}")
        collected = {}

        if any(fragment in path_root for fragment in settings.excludes):
            return {}

        current = Path(path_root)

        if current.is_dir():
            for child_path in current.iterdir():
                collected.update(self.discover(str(child_path)))

            self._collected.update(collected)

            return dict(self._collected)
        if current.suffix not in settings.types:
            return {}

        try:
            with open(current, "r") as infile:
                self._collected[str(current.resolve())] = infile.read()
        except (OSError, UnicodeDecodeError):
            # Unreadable or non-text files are skipped best-effort
            # (bug fix: the bare `except:` also swallowed KeyboardInterrupt).
            pass

        return dict(self._collected)

    def index(self, path: str, content: str):
        """Record trigrams and line offsets for *path*'s *content*.

        Bug fixes: the original re-read the content from
        ``self._collected[path]``, which raised KeyError whenever
        ``index`` was called without a prior ``discover`` (exactly what
        the unit tests do) — the *content* argument is now used directly
        and cached. An unused ``Path(path)`` local was removed.
        """
        self._collected[path] = content

        self._trigrams[path] = {
            content[idx : idx + 3] for idx in range(len(content) - 2)
        }

        self._lines[path] = {}

        offset = 0
        for line_number, line in enumerate(content.split("\n")):
            self._lines[path][line_number] = (offset, offset + len(line))
            # +1 accounts for the "\n" separator (bug fix: offsets
            # previously drifted one character per line).
            offset += len(line) + 1

    def query(self, query: str):
        """Return confirmed hits: trigram pre-filter, then regex scan."""
        trigram_results = self.search_trigrams(query)
        return self.search_content(query, trigram_results)

    def find_closest_line(self, path, index, offset=0):
        """Return the line number whose offset range contains *index*."""
        content = self._lines[path]

        for line_number in content:
            start, end = content[line_number]
            if start <= index <= end:
                return line_number

        logger.error(f"{path} {index}")
        logger.error(content)
        return 0

    def search_content(self, query: str, leads):
        """Confirm *leads* (from ``search_trigrams``) with a regex scan.

        Each hit is ``(path, start offset, end offset, start line,
        end line)`` widened by five lines of context on either side.
        """
        confirmed = []
        uniques = 0
        for lead in leads:
            lead_path = lead[1]
            lead_content = self._collected[lead_path]
            hits_in_lead = []
            for hit in re.finditer(query, lead_content):
                start_line = self.find_closest_line(lead_path, hit.start())
                end_line = self.find_closest_line(lead_path, hit.end())

                start_line_range = max(0, start_line - 5)
                end_line_range = min(len(self._lines[lead_path]) - 1, end_line + 5)

                hits_in_lead.append(
                    (
                        lead_path,
                        self._lines[lead_path][start_line_range][0],
                        self._lines[lead_path][end_line_range][1],
                        start_line_range,
                        end_line_range,
                    )
                )

            if hits_in_lead:
                confirmed.extend(hits_in_lead)
                uniques += 1

        logger.info(f"{len(confirmed)} hits in {uniques} files.")
        return confirmed

    def search_trigrams(self, query: str):
        """Rank indexed files by the share of *query* trigrams they contain."""
        query_trigrams = [query[idx : idx + 3] for idx in range(len(query) - 2)]
        results = []

        # Bug fix: queries shorter than 3 characters produce no trigrams
        # and previously crashed with ZeroDivisionError below.
        if not query_trigrams:
            return results

        for item in self.trigrams:
            shared = self.trigrams[item].intersection(query_trigrams)
            ratio = len(shared) / len(query_trigrams)
            if ratio < self.trigram_threshold:
                continue

            results.append((ratio, item, list(shared)))

        results.sort(reverse=True, key=lambda x: x[0])

        logger.info(f"Narrowed down to {len(results)} files via trigram search")

        return results
|
11
src/logger.py
Normal file
11
src/logger.py
Normal file
|
@ -0,0 +1,11 @@
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
|
||||||
|
|
||||||
|
def get_logger(name):
    """Return a logger for *name* that writes INFO+ records to stdout.

    Bug fix: ``logging.getLogger`` returns a cached logger per name, so
    every previous call attached another StreamHandler and each record
    was printed once per call site. Handlers are now attached only once.
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    if not logger.handlers:
        handler = logging.StreamHandler(sys.stdout)
        logger.addHandler(handler)

    return logger
|
84
src/server.py
Normal file
84
src/server.py
Normal file
|
@ -0,0 +1,84 @@
|
||||||
|
import json
|
||||||
|
|
||||||
|
import socket
|
||||||
|
import pyinotify
|
||||||
|
import attr
|
||||||
|
|
||||||
|
from watcher import WatchHandler
|
||||||
|
from indexer import Indexer
|
||||||
|
from constants import QUERY_STRING_LENGTH
|
||||||
|
|
||||||
|
import settings
|
||||||
|
|
||||||
|
from logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@attr.s
class Server:
    """TCP server that answers trigram-search queries.

    ``run`` discovers and indexes everything under the watched paths,
    starts a pyinotify watcher for live re-indexing, then blocks in the
    socket accept loop.
    """

    indexer = attr.ib()
    watched = attr.ib()  # list of directory paths to index and watch
    _notifier = attr.ib(default=None)  # pyinotify.ThreadedNotifier once started
    _socket = attr.ib(default=None)  # listening socket once opened

    def _handle_socket(self, *, socket):
        """Accept-and-respond loop.

        Note: *socket* here is a socket object — the parameter name
        shadows the module inside this method (kept for compatibility).
        """
        socket.bind((settings.SOCKET_HOST, settings.SOCKET_PORT))
        socket.listen()

        # bug fix: stray "$" (JS template syntax) removed from the log output
        logger.info(f"Listening on {settings.SOCKET_HOST}:{settings.SOCKET_PORT}")

        while True:
            conn, _ = socket.accept()
            query_string = conn.recv(QUERY_STRING_LENGTH)
            logger.info(f"Query: {query_string}")
            if query_string:
                try:
                    query_results = self.indexer.query(query_string.decode())
                    response = json.dumps(query_results).encode()
                    # Bug fix: the client reads exactly 4 bytes as the
                    # length prefix, so a short prefix such as "42" left
                    # payload bytes inside that read. Zero-pad to 4
                    # digits (responses > 9999 bytes still overflow the
                    # prefix — TODO widen the protocol).
                    response_length = f"{len(response):04d}"
                    conn.sendall(response_length.encode())
                    conn.sendall(response)
                except KeyboardInterrupt:
                    # bug fix: `raise e` referenced an undefined name here
                    raise
                except Exception as e:
                    logger.exception(e)

    def _start_socket(self):
        """Open the TCP socket and run the accept loop; re-raise failures."""
        try:
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as socket_obj:
                self._socket = socket_obj
                self._handle_socket(socket=socket_obj)
        except Exception as e:
            logger.exception(e)
            raise

    def _start_watch(self):
        """Start a threaded pyinotify notifier over every watched path."""
        watch_manager = pyinotify.WatchManager()

        for path in self.watched:
            logger.info(f"Watching {path}")
            watch_manager.add_watch(path, pyinotify.ALL_EVENTS, rec=True)

        event_handler = WatchHandler(indexer=self.indexer)
        notifier = pyinotify.ThreadedNotifier(watch_manager, event_handler)
        notifier.start()
        self._notifier = notifier

    def run(self):
        """Index the watched paths, then serve queries until interrupted."""
        collected = {}

        for watched_path in self.watched:
            logger.info(f"Collecting files from {watched_path}")
            collected.update(self.indexer.discover(watched_path))

        for path in collected:
            logger.info(f"Indexing {path}")
            self.indexer.index(path, collected[path])

        try:
            self._start_watch()
            self._start_socket()
        finally:
            # Bug fix: the bare `except:` silently swallowed every error
            # (including KeyboardInterrupt) and raised AttributeError when
            # a failure occurred before the socket/notifier existed.
            # Cleanup now always runs and the original error propagates.
            if self._socket is not None:
                self._socket.close()
            if self._notifier is not None:
                self._notifier.stop()
|
6
src/settings.py
Normal file
6
src/settings.py
Normal file
|
@ -0,0 +1,6 @@
|
||||||
|
# Address the search server binds to and the client connects to.
SOCKET_HOST = "127.0.0.1"
SOCKET_PORT = 65126

# Path fragments that make discover() skip a path entirely.
excludes = ["node_modules", ".git", ".venv"]
# File suffixes the indexer will read.
types = [".py", ".js"]
# Minimum shared-trigram ratio for a file to stay a search candidate
# (passed to Indexer as trigram_threshold by the CLI).
SIGNIFICANCE_THRESHOLD = 0.7
|
49
src/test_indexer.py
Normal file
49
src/test_indexer.py
Normal file
|
@ -0,0 +1,49 @@
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from indexer import Indexer
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
def indexer():
    """A fresh Indexer with the default (zero) trigram threshold."""
    return Indexer()
|
||||||
|
|
||||||
|
|
||||||
|
def test_indexer_builds_trigram_set_for_given_document(indexer):
    """index() records every 3-character sliding window of the content."""
    mock_document = "now that's a doc"
    mock_path = "/home/documents/cool_doc"

    # NOTE(review): Indexer.index reads self._collected[path] (see
    # indexer.py), which this direct call never populates — verify this
    # test actually passes against the current implementation.
    indexer.index(path=mock_path, content=mock_document)

    # All 14 sliding windows of the 16-character document, in order.
    expected_trigrams = [
        "now",
        "ow ",
        "w t",
        " th",
        "tha",
        "hat",
        "at'",
        "t's",
        "'s ",
        "s a",
        " a ",
        "a d",
        " do",
        "doc",
    ]

    assert indexer.trigrams == {mock_path: set(expected_trigrams)}
|
||||||
|
|
||||||
|
|
||||||
|
def test_indexer_preserves_previous_trigram_sets_on_index(indexer):
    """Indexing a second path must not discard the first path's trigrams."""
    mock_document_1 = "wow"
    mock_document_2 = "woa"
    mock_path_1 = "/home"
    mock_path_2 = "/somewhere_else"

    indexer.index(path=mock_path_1, content=mock_document_1)

    # A 3-character document yields exactly one trigram.
    assert indexer.trigrams == {mock_path_1: set(["wow"])}

    indexer.index(path=mock_path_2, content=mock_document_2)

    assert indexer.trigrams == {mock_path_1: set(["wow"]), mock_path_2: set(["woa"])}
|
17
src/watcher.py
Normal file
17
src/watcher.py
Normal file
|
@ -0,0 +1,17 @@
|
||||||
|
import pyinotify
|
||||||
|
import attr
|
||||||
|
|
||||||
|
from logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@attr.s
class WatchHandler(pyinotify.ProcessEvent):
    """pyinotify event handler that re-indexes files on modification."""

    indexer = attr.ib()  # the Indexer instance shared with the server

    def process_IN_MODIFY(self, event):
        """Re-discover the modified path and re-index it when content is found."""
        discovered = self.indexer.discover(event.pathname)
        if not discovered.get(event.pathname):
            return
        logger.info(f"Reindexed {event.pathname}")
        self.indexer.index(event.pathname, discovered[event.pathname])
|
6
tasks.py
Normal file
6
tasks.py
Normal file
|
@ -0,0 +1,6 @@
|
||||||
|
from invoke import task
|
||||||
|
|
||||||
|
|
||||||
|
@task
def lint(ctx):
    """Invoke task (`invoke lint`): run black over tasks.py and src/.

    NOTE(review): black is a formatter, so this rewrites files in place
    rather than only reporting issues, despite the task's name.
    """
    ctx.run("black *.py src")
|
Reference in a new issue