refactor: indices, settings

This commit is contained in:
Marc Cataford 2020-09-27 12:30:26 -04:00
parent 59795a5dec
commit 374685ae09
10 changed files with 234 additions and 123 deletions

View file

@ -1,11 +1,18 @@
from abc import ABC, abstractmethod
from typing import List, Optional
class IndexBase(ABC):
@abstractmethod
def index(self, content: str, haystack: Optional[List[str]]):
pass
@abstractmethod
def query(self, query: str) -> List[str]:
pass
class IndexerBase(ABC):
@abstractmethod
def discover(self, path: str):
pass
@abstractmethod
def index(self, path: str):
def index(self, paths: List[str]):
pass

View file

@ -1,9 +1,10 @@
import argparse
from pathlib import Path
from server import Server
from indexer import Indexer
from client import search
import settings
from settings import settings
parser = argparse.ArgumentParser()
@ -13,9 +14,10 @@ parser.add_argument("--q", required=False)
args = parser.parse_args()
if args.command == "start":
watched = [Path(p).expanduser() for p in settings.WATCHED]
server = Server(
indexer=Indexer(trigram_threshold=settings.SIGNIFICANCE_THRESHOLD),
watched=[".."],
indexer=Indexer(domain=watched),
watched=watched,
)
server.run()
elif args.command == "search":

View file

@ -1,7 +1,7 @@
import socket
import json
import settings
from settings import settings
from colors import highlight
@ -14,12 +14,12 @@ def search(query):
results = json.loads(s.recv(length).decode())
for result in results:
with open(result[0], "r") as infile:
highlighted_text = infile.read()[result[1] : result[2]].replace(
query, highlight(query)
)
line_number = result[3]
print(highlight(result[0]))
with open(result["key"], "r") as infile:
highlighted_text = infile.read()[
result["offset_start"] : result["offset_end"]
].replace(query, highlight(query))
line_number = result["line_start"]
print(highlight(result["key"]))
for l in highlighted_text.split("\n"):
print(f"{highlight(line_number)} {l}")
line_number += 1

View file

@ -1 +1,9 @@
SETTINGS_KEYS = [
"WATCHED",
"SOCKET_PORT",
"SOCKET_HOST",
"EXCLUDES",
"FILE_TYPES",
"SIGNIFICANCE_THRESHOLD",
]
QUERY_STRING_LENGTH = 1024

View file

@ -1,111 +1,82 @@
from base import IndexerBase
from pathlib import Path
from typing import Dict
from typing import Dict, List
import re
import attr
import settings
from settings import settings
from trigram_index import TrigramIndex
from line_index import LineIndex
from logger import get_logger
logger = get_logger(__name__)
@attr.s
class SearchResult:
key = attr.ib()
offset_start = attr.ib()
offset_end = attr.ib()
line_start = attr.ib()
line_end = attr.ib()
def to_dict(self):
return {
"key": self.key,
"offset_start": self.offset_start,
"offset_end": self.offset_end,
"line_start": self.line_start,
"line_end": self.line_end,
}
@attr.s
class Indexer(IndexerBase):
_trigrams = attr.ib(default=attr.Factory(dict))
# Indices
_trigram_index = attr.ib(default=attr.Factory(TrigramIndex))
_line_index = attr.ib(default=attr.Factory(LineIndex))
domain = attr.ib(default=attr.Factory(list))
_collected = attr.ib(default=attr.Factory(dict))
_lines = attr.ib(default=attr.Factory(dict))
trigram_threshold = attr.ib(default=0)
@property
def trigrams(self):
return dict(self._trigrams)
def index(self, paths: List[str]):
discovered = []
for path in paths:
discovered.extend(self._discover(path))
def discover(self, path_root: str) -> Dict[str, str]:
logger.info(f"Collecting {path_root}")
collected = {}
logger.info(f"[Discovery] Discovered {len(discovered)} files.")
if any([x in path_root for x in settings.excludes]):
return {}
current = Path(path_root)
if current.is_dir():
for child_path in current.iterdir():
collected.update(self.discover(str(child_path)))
self._collected.update(collected)
return dict(self._collected)
if current.suffix not in settings.types:
return {}
try:
with open(current, "r") as infile:
self._collected[str(current.resolve())] = infile.read()
except:
pass
return dict(self._collected)
def index(self, path: str, content: str):
p = Path(path)
self._trigrams[path] = set()
for idx in range(len(content) - 2):
self._trigrams[path].add(content[idx : idx + 3])
self._lines[path] = {}
content = self._collected[path]
current, count = 0, 0
for line in self._collected[path].split("\n"):
self._lines[path][count] = (current, current + len(line))
current += len(line)
count += 1
self._preload(discovered)
for path in self._collected:
self._process(path)
def query(self, query: str):
trigram_results = self.search_trigrams(query)
confirmed = self.search_content(query, trigram_results)
trigram_results = self._trigram_index.query(query)
return confirmed
logger.info(f"Narrowed down to {len(trigram_results)} files via trigram search")
def find_closest_line(self, path, index, offset=0):
content = self._lines[path]
for l in content:
if content[l][0] <= index <= content[l][1]:
return l
logger.error(f"{path} {index}")
logger.error(content)
return 0
def search_content(self, query: str, leads):
confirmed = []
uniques = 0
for lead in leads:
for lead in trigram_results:
lead_content = self._collected[lead[1]]
results = re.finditer(query, lead_content)
hits_in_lead = []
for hit in results:
start_line = self.find_closest_line(lead[1], hit.start())
end_line = self.find_closest_line(lead[1], hit.end())
start_line_range = max(0, start_line - 5)
end_line_range = min(len(self._lines[lead[1]]) - 1, end_line + 5)
start_line, end_line = self._find_line_range(
lead[1], hit.start(), hit.end()
)
start_offset = self._line_index.query(lead[1])[start_line][0]
end_offset = self._line_index.query(lead[1])[end_line][1]
hits_in_lead.append(
(
lead[1],
self._lines[lead[1]][start_line_range][0],
self._lines[lead[1]][end_line_range][1],
start_line_range,
end_line_range,
SearchResult(
key=lead[1],
offset_start=start_offset,
offset_end=end_offset,
line_start=start_line,
line_end=end_line,
)
)
@ -114,22 +85,62 @@ class Indexer(IndexerBase):
uniques += 1
logger.info(f"{len(confirmed)} hits in {uniques} files.")
return confirmed
return [r.to_dict() for r in confirmed]
def search_trigrams(self, query: str):
query_trigrams = [query[idx : idx + 3] for idx in range(len(query) - 2)]
results = []
def _discover(self, path_root: str) -> Dict[str, str]:
collected = []
current = Path(path_root)
for item in self.trigrams:
shared = self.trigrams[item].intersection(query_trigrams)
ratio = len(shared) / len(query_trigrams)
if ratio < self.trigram_threshold:
continue
# Avoid any excluded paths
if any([current.match(x) for x in settings.EXCLUDES]):
logger.info(f"[Discovery] {path_root} excluded.")
return []
results.append((ratio, item, list(shared)))
if current.is_dir():
logger.info(list(current.iterdir()))
for child_path in current.iterdir():
collected.extend(self._discover(str(child_path)))
results.sort(reverse=True, key=lambda x: x[0])
return collected
logger.info(f"Narrowed down to {len(results)} files via trigram search")
if current.suffix not in settings.FILE_TYPES:
return []
return results
logger.info(f"Collected {path_root}")
return [path_root]
def _preload(self, discovered: List[str]):
for discovered_file in discovered:
try:
with open(discovered_file, "r") as infile:
self._collected[discovered_file] = infile.read()
logger.info(f"[Preloading] Loaded {discovered_file} in memory")
except:
logger.error(f"Could not read {discovered_file}, skipping.")
def _process(self, path: str):
p = Path(path)
content = self._collected[path]
self._trigram_index.index(path, content)
content = self._collected[path]
self._line_index.index(path, content)
def _find_closest_line(self, path, index):
content = self._line_index.query(path)
for l in content:
if content[l][0] <= index <= content[l][1]:
return l
return 0
def _find_line_range(self, key, start, end, padding=5):
start_line = self._find_closest_line(key, start)
end_line = self._find_closest_line(key, end)
start_line_range = max(0, start_line - 5)
end_line_range = min(len(self._line_index.query(key)) - 1, end_line + 5)
return (start_line_range, end_line_range)

19
src/line_index.py Normal file
View file

@ -0,0 +1,19 @@
from base import IndexBase
import attr
@attr.s
class LineIndex(IndexBase):
_lines = attr.ib(default=attr.Factory(dict))
def index(self, key: str, content: str):
self._lines[key] = {}
current, count = 0, 0
for line in content.split("\n"):
self._lines[key][count] = (current, current + len(line))
current += len(line)
count += 1
def query(self, key: str):
return self._lines[key]

View file

@ -8,7 +8,7 @@ from watcher import WatchHandler
from indexer import Indexer
from constants import QUERY_STRING_LENGTH
import settings
from settings import settings
from logger import get_logger
@ -68,13 +68,7 @@ class Server:
def run(self):
collected = {}
for watched_path in self.watched:
logger.info(f"Collecting files from ${watched_path}")
collected.update(self.indexer.discover(watched_path))
for c in collected:
logger.info(f"Indexing ${c}")
self.indexer.index(c, collected[c])
self.indexer.index(self.watched)
try:
self._start_watch()

View file

@ -1,6 +1,43 @@
SOCKET_HOST = "127.0.0.1"
SOCKET_PORT = 65126
import json
excludes = ["node_modules", ".git", ".venv"]
types = [".py", ".js"]
SIGNIFICANCE_THRESHOLD = 0.7
from pathlib import Path
import attr
from constants import SETTINGS_KEYS
SETTINGS_PATH = "~/.codesearchrc"
default_settings = {
"SOCKET_HOST": "127.0.0.1",
"SOCKET_PORT": 65126,
"EXCLUDES": [],
"FILE_TYPES": [],
"SIGNIFICANCE_THRESHOLD": 0,
"WATCHED": [],
}
@attr.s
class Settings:
settings = attr.ib(default=attr.Factory(dict))
def from_file(self, path: str):
settings_path = Path(SETTINGS_PATH).expanduser()
if not settings_path.exists():
self.settings = default_settings
return
with open(path, "r") as settings_file:
self.settings = json.load(settings_file)
def __getattr__(self, key):
if key not in SETTINGS_KEYS:
raise KeyError(f"{key} not a valid settings property")
return self.settings[key]
settings = Settings()
settings.from_file(Path(SETTINGS_PATH).expanduser())

36
src/trigram_index.py Normal file
View file

@ -0,0 +1,36 @@
from typing import List, Optional
import attr
from settings import settings
from base import IndexBase
@attr.s
class TrigramIndex(IndexBase):
_trigrams = attr.ib(default=attr.Factory(dict))
_threshold = attr.ib(default=settings.SIGNIFICANCE_THRESHOLD)
def index(self, key: str, content: str):
self._trigrams[key] = self._trigramize(content)
def query(self, query: str, haystack: Optional[List[str]] = None) -> List[str]:
if not haystack:
haystack = self._trigrams
query_trigrams = self._trigramize(query)
results = []
for item in haystack:
shared = self._trigrams[item].intersection(query_trigrams)
ratio = len(shared) / len(query_trigrams)
if ratio < self._threshold:
continue
results.append((ratio, item, list(shared)))
results.sort(reverse=True, key=lambda x: x[0])
return results
def _trigramize(self, content: str) -> List[str]:
return {content[pos : pos + 3].lower() for pos in range(len(content) - 2)}

View file

@ -11,7 +11,4 @@ class WatchHandler(pyinotify.ProcessEvent):
indexer = attr.ib()
def process_IN_MODIFY(self, event):
c = self.indexer.discover(event.pathname)
if c.get(event.pathname):
logger.info(f"Reindexed {event.pathname}")
self.indexer.index(event.pathname, c[event.pathname])
self.indexer.index([event.pathname])