wip: functional version
This commit is contained in:
parent
73d90beb9a
commit
59795a5dec
15 changed files with 411 additions and 0 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -2,6 +2,7 @@
|
|||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
*.sw[a-z]
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
|
21
requirements.txt
Normal file
21
requirements.txt
Normal file
|
@ -0,0 +1,21 @@
|
|||
appdirs==1.4.4
|
||||
attr==0.3.1
|
||||
attrs==20.2.0
|
||||
black==20.8b1
|
||||
click==7.1.2
|
||||
iniconfig==1.0.1
|
||||
invoke==1.4.1
|
||||
more-itertools==8.5.0
|
||||
mypy-extensions==0.4.3
|
||||
packaging==20.4
|
||||
pathspec==0.8.0
|
||||
pluggy==0.13.1
|
||||
py==1.9.0
|
||||
pyinotify==0.9.6
|
||||
pyparsing==2.4.7
|
||||
pytest==6.0.2
|
||||
regex==2020.7.14
|
||||
six==1.15.0
|
||||
toml==0.10.1
|
||||
typed-ast==1.4.1
|
||||
typing-extensions==3.7.4.3
|
15
script/bootstrap
Normal file
15
script/bootstrap
Normal file
|
@ -0,0 +1,15 @@
|
|||
VENV=codesearch.venv
|
||||
|
||||
#################################################################
|
||||
# Bootstrapping sets up the Python 3.8 venv that allows the use #
|
||||
# of the invoke commands. #
|
||||
#################################################################
|
||||
|
||||
{
|
||||
pyenv virtualenv-delete -f $VENV
|
||||
pyenv virtualenv $VENV &&
|
||||
pyenv activate $VENV &&
|
||||
python -m pip install -U pip &&
|
||||
pip install -r requirements.txt &&
|
||||
echo "✨ Good to go! ✨"
|
||||
}
|
11
src/base.py
Normal file
11
src/base.py
Normal file
|
@ -0,0 +1,11 @@
|
|||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class IndexerBase(ABC):
|
||||
@abstractmethod
|
||||
def discover(self, path: str):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def index(self, path: str):
|
||||
pass
|
22
src/cli.py
Normal file
22
src/cli.py
Normal file
|
@ -0,0 +1,22 @@
|
|||
import argparse
|
||||
|
||||
from server import Server
|
||||
from indexer import Indexer
|
||||
from client import search
|
||||
import settings
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument("command")
|
||||
parser.add_argument("--q", required=False)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.command == "start":
|
||||
server = Server(
|
||||
indexer=Indexer(trigram_threshold=settings.SIGNIFICANCE_THRESHOLD),
|
||||
watched=[".."],
|
||||
)
|
||||
server.run()
|
||||
elif args.command == "search":
|
||||
search(args.q)
|
28
src/client.py
Normal file
28
src/client.py
Normal file
|
@ -0,0 +1,28 @@
|
|||
import socket
|
||||
import json
|
||||
|
||||
import settings
|
||||
|
||||
from colors import highlight
|
||||
|
||||
|
||||
def search(query):
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.connect((settings.SOCKET_HOST, settings.SOCKET_PORT))
|
||||
s.sendall(query.encode())
|
||||
length = int(s.recv(4).decode())
|
||||
results = json.loads(s.recv(length).decode())
|
||||
|
||||
for result in results:
|
||||
with open(result[0], "r") as infile:
|
||||
highlighted_text = infile.read()[result[1] : result[2]].replace(
|
||||
query, highlight(query)
|
||||
)
|
||||
line_number = result[3]
|
||||
print(highlight(result[0]))
|
||||
for l in highlighted_text.split("\n"):
|
||||
print(f"{highlight(line_number)} {l}")
|
||||
line_number += 1
|
||||
print("\n\n")
|
||||
|
||||
s.close()
|
4
src/colors.py
Normal file
4
src/colors.py
Normal file
|
@ -0,0 +1,4 @@
|
|||
def highlight(text):
|
||||
GREEN = "\033[92m"
|
||||
ENDC = "\033[0m"
|
||||
return f"{GREEN}{text}{ENDC}"
|
1
src/constants.py
Normal file
1
src/constants.py
Normal file
|
@ -0,0 +1 @@
|
|||
QUERY_STRING_LENGTH = 1024
|
135
src/indexer.py
Normal file
135
src/indexer.py
Normal file
|
@ -0,0 +1,135 @@
|
|||
from base import IndexerBase
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
import re
|
||||
|
||||
import attr
|
||||
|
||||
import settings
|
||||
|
||||
|
||||
from logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@attr.s
|
||||
class Indexer(IndexerBase):
|
||||
_trigrams = attr.ib(default=attr.Factory(dict))
|
||||
_collected = attr.ib(default=attr.Factory(dict))
|
||||
_lines = attr.ib(default=attr.Factory(dict))
|
||||
trigram_threshold = attr.ib(default=0)
|
||||
|
||||
@property
|
||||
def trigrams(self):
|
||||
return dict(self._trigrams)
|
||||
|
||||
def discover(self, path_root: str) -> Dict[str, str]:
|
||||
logger.info(f"Collecting {path_root}")
|
||||
collected = {}
|
||||
|
||||
if any([x in path_root for x in settings.excludes]):
|
||||
return {}
|
||||
|
||||
current = Path(path_root)
|
||||
|
||||
if current.is_dir():
|
||||
for child_path in current.iterdir():
|
||||
collected.update(self.discover(str(child_path)))
|
||||
|
||||
self._collected.update(collected)
|
||||
|
||||
return dict(self._collected)
|
||||
if current.suffix not in settings.types:
|
||||
return {}
|
||||
|
||||
try:
|
||||
with open(current, "r") as infile:
|
||||
self._collected[str(current.resolve())] = infile.read()
|
||||
except:
|
||||
pass
|
||||
|
||||
return dict(self._collected)
|
||||
|
||||
def index(self, path: str, content: str):
|
||||
p = Path(path)
|
||||
|
||||
self._trigrams[path] = set()
|
||||
|
||||
for idx in range(len(content) - 2):
|
||||
self._trigrams[path].add(content[idx : idx + 3])
|
||||
|
||||
self._lines[path] = {}
|
||||
|
||||
content = self._collected[path]
|
||||
current, count = 0, 0
|
||||
for line in self._collected[path].split("\n"):
|
||||
self._lines[path][count] = (current, current + len(line))
|
||||
current += len(line)
|
||||
count += 1
|
||||
|
||||
def query(self, query: str):
|
||||
trigram_results = self.search_trigrams(query)
|
||||
confirmed = self.search_content(query, trigram_results)
|
||||
|
||||
return confirmed
|
||||
|
||||
def find_closest_line(self, path, index, offset=0):
|
||||
content = self._lines[path]
|
||||
|
||||
for l in content:
|
||||
if content[l][0] <= index <= content[l][1]:
|
||||
return l
|
||||
|
||||
logger.error(f"{path} {index}")
|
||||
logger.error(content)
|
||||
return 0
|
||||
|
||||
def search_content(self, query: str, leads):
|
||||
confirmed = []
|
||||
uniques = 0
|
||||
for lead in leads:
|
||||
lead_content = self._collected[lead[1]]
|
||||
results = re.finditer(query, lead_content)
|
||||
hits_in_lead = []
|
||||
for hit in results:
|
||||
start_line = self.find_closest_line(lead[1], hit.start())
|
||||
end_line = self.find_closest_line(lead[1], hit.end())
|
||||
|
||||
start_line_range = max(0, start_line - 5)
|
||||
end_line_range = min(len(self._lines[lead[1]]) - 1, end_line + 5)
|
||||
|
||||
hits_in_lead.append(
|
||||
(
|
||||
lead[1],
|
||||
self._lines[lead[1]][start_line_range][0],
|
||||
self._lines[lead[1]][end_line_range][1],
|
||||
start_line_range,
|
||||
end_line_range,
|
||||
)
|
||||
)
|
||||
|
||||
if hits_in_lead:
|
||||
confirmed.extend(hits_in_lead)
|
||||
uniques += 1
|
||||
|
||||
logger.info(f"{len(confirmed)} hits in {uniques} files.")
|
||||
return confirmed
|
||||
|
||||
def search_trigrams(self, query: str):
|
||||
query_trigrams = [query[idx : idx + 3] for idx in range(len(query) - 2)]
|
||||
results = []
|
||||
|
||||
for item in self.trigrams:
|
||||
shared = self.trigrams[item].intersection(query_trigrams)
|
||||
ratio = len(shared) / len(query_trigrams)
|
||||
if ratio < self.trigram_threshold:
|
||||
continue
|
||||
|
||||
results.append((ratio, item, list(shared)))
|
||||
|
||||
results.sort(reverse=True, key=lambda x: x[0])
|
||||
|
||||
logger.info(f"Narrowed down to {len(results)} files via trigram search")
|
||||
|
||||
return results
|
11
src/logger.py
Normal file
11
src/logger.py
Normal file
|
@ -0,0 +1,11 @@
|
|||
import logging
|
||||
import sys
|
||||
|
||||
|
||||
def get_logger(name):
|
||||
logger = logging.getLogger(name)
|
||||
logger.setLevel(logging.INFO)
|
||||
handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(handler)
|
||||
|
||||
return logger
|
84
src/server.py
Normal file
84
src/server.py
Normal file
|
@ -0,0 +1,84 @@
|
|||
import json
|
||||
|
||||
import socket
|
||||
import pyinotify
|
||||
import attr
|
||||
|
||||
from watcher import WatchHandler
|
||||
from indexer import Indexer
|
||||
from constants import QUERY_STRING_LENGTH
|
||||
|
||||
import settings
|
||||
|
||||
from logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@attr.s
|
||||
class Server:
|
||||
indexer = attr.ib()
|
||||
watched = attr.ib()
|
||||
_notifier = attr.ib(default=None)
|
||||
_socket = attr.ib(default=None)
|
||||
|
||||
def _handle_socket(self, *, socket):
|
||||
socket.bind((settings.SOCKET_HOST, settings.SOCKET_PORT))
|
||||
socket.listen()
|
||||
|
||||
logger.info(f"Listening on ${settings.SOCKET_HOST}:${settings.SOCKET_PORT}")
|
||||
|
||||
while True:
|
||||
conn, _ = socket.accept()
|
||||
query_string = conn.recv(QUERY_STRING_LENGTH)
|
||||
logger.info(f"Query: {query_string}")
|
||||
if query_string:
|
||||
try:
|
||||
query_results = self.indexer.query(query_string.decode())
|
||||
response = json.dumps(query_results).encode()
|
||||
response_length = str(len(response))
|
||||
conn.sendall(response_length.encode())
|
||||
conn.sendall(response)
|
||||
except KeyboardInterrupt:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
|
||||
def _start_socket(self):
|
||||
try:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as socket_obj:
|
||||
self._socket = socket_obj
|
||||
self._handle_socket(socket=socket_obj)
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
raise e
|
||||
|
||||
def _start_watch(self):
|
||||
watch_manager = pyinotify.WatchManager()
|
||||
|
||||
for path in self.watched:
|
||||
logger.info(f"Watching ${path}")
|
||||
watch_manager.add_watch(path, pyinotify.ALL_EVENTS, rec=True)
|
||||
|
||||
event_handler = WatchHandler(indexer=self.indexer)
|
||||
notifier = pyinotify.ThreadedNotifier(watch_manager, event_handler)
|
||||
notifier.start()
|
||||
self._notifier = notifier
|
||||
|
||||
def run(self):
|
||||
collected = {}
|
||||
|
||||
for watched_path in self.watched:
|
||||
logger.info(f"Collecting files from ${watched_path}")
|
||||
collected.update(self.indexer.discover(watched_path))
|
||||
|
||||
for c in collected:
|
||||
logger.info(f"Indexing ${c}")
|
||||
self.indexer.index(c, collected[c])
|
||||
|
||||
try:
|
||||
self._start_watch()
|
||||
self._start_socket()
|
||||
except:
|
||||
self._socket.close()
|
||||
self._notifier.stop()
|
6
src/settings.py
Normal file
6
src/settings.py
Normal file
|
@ -0,0 +1,6 @@
|
|||
SOCKET_HOST = "127.0.0.1"
|
||||
SOCKET_PORT = 65126
|
||||
|
||||
excludes = ["node_modules", ".git", ".venv"]
|
||||
types = [".py", ".js"]
|
||||
SIGNIFICANCE_THRESHOLD = 0.7
|
49
src/test_indexer.py
Normal file
49
src/test_indexer.py
Normal file
|
@ -0,0 +1,49 @@
|
|||
import pytest
|
||||
|
||||
from indexer import Indexer
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def indexer():
|
||||
return Indexer()
|
||||
|
||||
|
||||
def test_indexer_builds_trigram_set_for_given_document(indexer):
|
||||
mock_document = "now that's a doc"
|
||||
mock_path = "/home/documents/cool_doc"
|
||||
|
||||
indexer.index(path=mock_path, content=mock_document)
|
||||
|
||||
expected_trigrams = [
|
||||
"now",
|
||||
"ow ",
|
||||
"w t",
|
||||
" th",
|
||||
"tha",
|
||||
"hat",
|
||||
"at'",
|
||||
"t's",
|
||||
"'s ",
|
||||
"s a",
|
||||
" a ",
|
||||
"a d",
|
||||
" do",
|
||||
"doc",
|
||||
]
|
||||
|
||||
assert indexer.trigrams == {mock_path: set(expected_trigrams)}
|
||||
|
||||
|
||||
def test_indexer_preserves_previous_trigram_sets_on_index(indexer):
|
||||
mock_document_1 = "wow"
|
||||
mock_document_2 = "woa"
|
||||
mock_path_1 = "/home"
|
||||
mock_path_2 = "/somewhere_else"
|
||||
|
||||
indexer.index(path=mock_path_1, content=mock_document_1)
|
||||
|
||||
assert indexer.trigrams == {mock_path_1: set(["wow"])}
|
||||
|
||||
indexer.index(path=mock_path_2, content=mock_document_2)
|
||||
|
||||
assert indexer.trigrams == {mock_path_1: set(["wow"]), mock_path_2: set(["woa"])}
|
17
src/watcher.py
Normal file
17
src/watcher.py
Normal file
|
@ -0,0 +1,17 @@
|
|||
import pyinotify
|
||||
import attr
|
||||
|
||||
from logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@attr.s
|
||||
class WatchHandler(pyinotify.ProcessEvent):
|
||||
indexer = attr.ib()
|
||||
|
||||
def process_IN_MODIFY(self, event):
|
||||
c = self.indexer.discover(event.pathname)
|
||||
if c.get(event.pathname):
|
||||
logger.info(f"Reindexed {event.pathname}")
|
||||
self.indexer.index(event.pathname, c[event.pathname])
|
6
tasks.py
Normal file
6
tasks.py
Normal file
|
@ -0,0 +1,6 @@
|
|||
from invoke import task
|
||||
|
||||
|
||||
@task
|
||||
def lint(ctx):
|
||||
ctx.run("black *.py src")
|
Reference in a new issue