Commit

Export recent changes.
- Fix issues with path escaping in MySQL.
- Remove in-place operators from RDF primitives.
- Minor bugfixes in other areas.
grr-export committed Jun 27, 2019
1 parent ef3062e commit c2fd405
Showing 17 changed files with 229 additions and 132 deletions.
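
The removal of in-place operators (second bullet) avoids an aliasing hazard. A minimal sketch, with a hypothetical RdfInt standing in for the RDF integer primitives changed below: once __iadd__ mutates the wrapped value, augmented assignment silently updates every reference to the same object.

class RdfInt(object):

  def __init__(self, value):
    self._value = value

  def __add__(self, other):  # kept by this commit: returns a plain value
    return self._value + other

  def __iadd__(self, other):  # the kind of operator this commit removes
    self._value += other
    return self

a = b = RdfInt(1)
a += 1  # calls __iadd__, mutating the object shared by both names
assert b._value == 2  # b changed too, although only a was assigned

Without __iadd__, a += 1 falls back to __add__ and rebinds a to a plain int, leaving b untouched.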
50 changes: 0 additions & 50 deletions grr/core/grr_response_core/lib/rdfvalue.py
@@ -522,40 +522,24 @@ def __and__(self, other):
def __rand__(self, other):
return self._value & other

def __iand__(self, other):
self._value &= other
return self

def __or__(self, other):
return self._value | other

def __ror__(self, other):
return self._value | other

def __ior__(self, other):
self._value |= other
return self

def __add__(self, other):
return self._value + other

def __radd__(self, other):
return self._value + other

def __iadd__(self, other):
self._value += other
return self

def __sub__(self, other):
return self._value - other

def __rsub__(self, other):
return other - self._value

def __isub__(self, other):
self._value -= other
return self

def __mul__(self, other):
return self._value * other

@@ -707,15 +691,6 @@ def __add__(self, other):

return NotImplemented

def __iadd__(self, other):
# TODO(hanuszczak): Disallow `float` initialization.
if isinstance(other, (int, float, Duration)):
# Assume other is in seconds
self._value += compatibility.builtins.int(other * self.converter)
return self

return NotImplemented

def __mul__(self, other):
# TODO(hanuszczak): Disallow `float` initialization.
if isinstance(other, (int, float, Duration)):
@@ -738,15 +713,6 @@ def __sub__(self, other):

return NotImplemented

def __isub__(self, other):
# TODO(hanuszczak): Disallow `float` initialization.
if isinstance(other, (int, float, Duration)):
# Assume other is in seconds
self._value -= compatibility.builtins.int(other * self.converter)
return self

return NotImplemented

@classmethod
def _ParseFromHumanReadable(cls, string, eoy=False):
"""Parse a human readable string of a timestamp (in local time).
@@ -890,14 +856,6 @@ def __add__(self, other):

return NotImplemented

def __iadd__(self, other):
if isinstance(other, (int, float, Duration)):
# Assume other is in seconds
self._value += other
return self

return NotImplemented

def __mul__(self, other):
if isinstance(other, (int, float, Duration)):
return self.__class__(int(self._value * other))
@@ -914,14 +872,6 @@ def __sub__(self, other):

return NotImplemented

def __isub__(self, other):
if isinstance(other, (int, float, Duration)):
# Assume other is in seconds
self._value -= other
return self

return NotImplemented

def __abs__(self):
return Duration(abs(self._value))

3 changes: 2 additions & 1 deletion grr/server/grr_response_server/blob_stores/dual_blob_store.py
@@ -93,10 +93,10 @@ def __init__(self,
self._primary = _InstantiateBlobStore(primary)
self._secondary = _InstantiateBlobStore(secondary)
self._queue = queue.Queue(_SECONDARY_WRITE_QUEUE_MAX_LENGTH)
self._thread_running = True
self._thread = threading.Thread(target=self._WriteBlobsIntoSecondary)
self._thread.daemon = True
self._thread.start()
self._thread_running = True

def WriteBlobs(self,
blob_id_data_map):
@@ -140,3 +140,4 @@ def _WriteBlobsIntoSecondary(self):
# Failed writes to secondary are not critical, because primary is read
# from.
logging.warn(e)
self._queue.task_done()
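
The task_done() call added above completes the queue handshake that the new tests rely on: queue.join() returns only after task_done() has been called once per queued item. A minimal sketch of the pattern (standard library only, Python 3 names):

import queue
import threading

q = queue.Queue(maxsize=10)

def Worker():
  while True:
    item = q.get()
    try:
      pass  # write the item to the secondary store; failures are tolerated
    finally:
      q.task_done()  # without this, q.join() below would block forever

thread = threading.Thread(target=Worker)
thread.daemon = True
thread.start()

q.put(b"blob")
q.join()  # returns once the worker has acknowledged every queued item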
54 changes: 48 additions & 6 deletions grr/server/grr_response_server/blob_stores/dual_blob_store_test.py
@@ -5,27 +5,44 @@
from __future__ import division
from __future__ import unicode_literals

from absl.testing import absltest
from absl import app
from future.moves import queue

from grr_response_core.lib.util import compatibility
from grr_response_server import blob_store_test_mixin
from grr_response_server.blob_stores import dual_blob_store
from grr_response_server.blob_stores import registry_init
from grr_response_server.databases import mem_blobs
from grr_response_server.rdfvalues import objects as rdf_objects
from grr.test_lib import test_lib


def _StopBackgroundThread(dual_bs):
"""Stops the background thread that writes to the secondary."""
dual_bs._thread_running = False
# Unblock _queue.get() to recheck loop condition.
dual_bs._queue.put_nowait({})
try:
dual_bs._queue.put_nowait({})
except queue.Full:
pass # At least one entry is in the queue, which works as well.
dual_bs._thread.join(timeout=1)


class DualBlobStoreTest(blob_store_test_mixin.BlobStoreTestMixin):
def _WaitUntilQueueIsEmpty(dual_bs):
dual_bs._queue.join()


class DualBlobStoreTest(
blob_store_test_mixin.BlobStoreTestMixin,
test_lib.GRRBaseTest,
):

@classmethod
def setUpClass(cls):
registry_init.RegisterBlobStores()

def CreateBlobStore(self):
backing_store_name = compatibility.GetName(mem_blobs.InMemoryDBBlobsMixin)
backing_store_name = compatibility.GetName(mem_blobs.InMemoryBlobStore)
bs = dual_blob_store.DualBlobStore(backing_store_name, backing_store_name)
return bs, lambda: _StopBackgroundThread(bs)

@@ -38,6 +55,7 @@ def testSingleBlobIsWrittenToSecondary(self):
blob_data = b"abcdef"

self.blob_store.WriteBlobs({blob_id: blob_data})
_WaitUntilQueueIsEmpty(self.blob_store.delegate)

result = self.secondary.ReadBlob(blob_id)
self.assertEqual(result, blob_data)
@@ -47,6 +65,7 @@ def testMultipleBlobsAreWrittenToSecondary(self):
blob_data = [b"a" * i for i in range(10)]

self.blob_store.WriteBlobs(dict(zip(blob_ids, blob_data)))
_WaitUntilQueueIsEmpty(self.blob_store.delegate)

result = self.secondary.ReadBlobs(blob_ids)
self.assertEqual(result, dict(zip(blob_ids, blob_data)))
@@ -56,14 +75,37 @@ def testWritesToPrimaryAreNotBlockedBySecondary(self):

limit = dual_blob_store._SECONDARY_WRITE_QUEUE_MAX_LENGTH + 1
blob_ids = [
rdf_objects.BlobID((b"%02d34567" % i) * 4) for i in range(limit)
rdf_objects.BlobID((b"%02d234567" % i) * 4) for i in range(limit)
]
blob_data = [b"a" * i for i in range(limit)]

self.blob_store.WriteBlobs(dict(zip(blob_ids, blob_data)))
result = self.blob_store.ReadBlobs(blob_ids)
self.assertEqual(result, dict(zip(blob_ids, blob_data)))

def testDiscardedSecondaryWritesAreMeasured(self):
_StopBackgroundThread(self.blob_store.delegate)

with self.assertStatsCounterDelta(
0,
"dual_blob_store_discard_count",
fields=["secondary", "InMemoryBlobStore"]):
for i in range(dual_blob_store._SECONDARY_WRITE_QUEUE_MAX_LENGTH):
blob_id = rdf_objects.BlobID((b"%02d234567" % i) * 4)
blob = b"a" * i
self.blob_store.WriteBlobs({blob_id: blob})

self.assertEqual(self.blob_store.delegate._queue.qsize(),
dual_blob_store._SECONDARY_WRITE_QUEUE_MAX_LENGTH)
with self.assertStatsCounterDelta(
3,
"dual_blob_store_discard_count",
fields=["secondary", "InMemoryBlobStore"]):
for i in range(3, 6):
blob_id = rdf_objects.BlobID((b"%02d234567" % i) * 4)
blob = b"a" * i
self.blob_store.WriteBlobs({blob_id: blob})


if __name__ == "__main__":
absltest.main()
app.run(test_lib.main)
3 changes: 3 additions & 0 deletions grr/server/grr_response_server/blob_stores/registry_init.py
@@ -9,6 +9,7 @@
from grr_response_server.blob_stores import db_blob_store
from grr_response_server.blob_stores import dual_blob_store
from grr_response_server.blob_stores import memory_stream_bs
from grr_response_server.databases import mem_blobs


def RegisterBlobStores():
@@ -20,3 +21,5 @@ def RegisterBlobStores():
blob_store.REGISTRY[compatibility.GetName(
memory_stream_bs.MemoryStreamBlobStore
)] = memory_stream_bs.MemoryStreamBlobStore
blob_store.REGISTRY[compatibility.GetName(
mem_blobs.InMemoryBlobStore)] = mem_blobs.InMemoryBlobStore
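
For context, the registry above is what makes name-based construction like DualBlobStore(backing_store_name, backing_store_name) in the test work. A minimal sketch of the lookup (the helper name here is hypothetical; the real code uses _InstantiateBlobStore in dual_blob_store.py):

REGISTRY = {}

class InMemoryBlobStore(object):
  """Stand-in for the class registered above."""

REGISTRY["InMemoryBlobStore"] = InMemoryBlobStore

def InstantiateBlobStore(name):
  try:
    cls = REGISTRY[name]
  except KeyError:
    raise ValueError("No blob store registered under name: %s" % name)
  return cls()

store = InstantiateBlobStore("InMemoryBlobStore")
assert isinstance(store, InMemoryBlobStore)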
55 changes: 55 additions & 0 deletions grr/server/grr_response_server/databases/db_paths_test.py
@@ -1337,6 +1337,61 @@ def testListChildPathInfosTimestampStatAndHashValue(self):
self.assertEqual(results_3[0].stat_entry.st_size, 1337)
self.assertEqual(results_3[0].hash_entry.sha256, b"thud")

def testListChildPathInfosBackslashes(self):
client_id = db_test_utils.InitializeClient(self.db)

path_info_1 = rdf_objects.PathInfo.OS(components=("\\", "\\\\", "\\\\\\"))
path_info_2 = rdf_objects.PathInfo.OS(components=("\\", "\\\\\\", "\\\\"))
path_info_3 = rdf_objects.PathInfo.OS(components=("\\", "foo\\bar", "baz"))
self.db.WritePathInfos(client_id, [path_info_1, path_info_2, path_info_3])

results_0 = self.db.ListChildPathInfos(
client_id=client_id,
path_type=rdf_objects.PathInfo.PathType.OS,
components=("\\",))
self.assertLen(results_0, 3)
self.assertEqual(results_0[0].components, ("\\", "\\\\"))
self.assertEqual(results_0[1].components, ("\\", "\\\\\\"))
self.assertEqual(results_0[2].components, ("\\", "foo\\bar"))

results_1 = self.db.ListChildPathInfos(
client_id=client_id,
path_type=rdf_objects.PathInfo.PathType.OS,
components=("\\", "\\\\"))
self.assertLen(results_1, 1)
self.assertEqual(results_1[0].components, ("\\", "\\\\", "\\\\\\"))

results_2 = self.db.ListChildPathInfos(
client_id=client_id,
path_type=rdf_objects.PathInfo.PathType.OS,
components=("\\", "\\\\\\"))
self.assertLen(results_2, 1)
self.assertEqual(results_2[0].components, ("\\", "\\\\\\", "\\\\"))

results_3 = self.db.ListChildPathInfos(
client_id=client_id,
path_type=rdf_objects.PathInfo.PathType.OS,
components=("\\", "foo\\bar"))
self.assertLen(results_3, 1)
self.assertEqual(results_3[0].components, ("\\", "foo\\bar", "baz"))

def testListChildPathInfosTSKRootVolume(self):
client_id = db_test_utils.InitializeClient(self.db)
volume = "\\\\?\\Volume{2d4fbbd3-0000-0000-0000-100000000000}"

path_info = rdf_objects.PathInfo.TSK(components=(volume, "foobar.txt"))
path_info.stat_entry.st_size = 42
self.db.WritePathInfos(client_id, [path_info])

results = self.db.ListChildPathInfos(
client_id=client_id,
path_type=rdf_objects.PathInfo.PathType.TSK,
components=(volume,))

self.assertLen(results, 1)
self.assertEqual(results[0].components, (volume, "foobar.txt"))
self.assertEqual(results[0].stat_entry.st_size, 42)

def testReadPathInfosHistoriesEmpty(self):
client_id = db_test_utils.InitializeClient(self.db)
result = self.db.ReadPathInfosHistories(client_id,
19 changes: 19 additions & 0 deletions grr/server/grr_response_server/databases/db_utils.py
@@ -82,6 +82,25 @@ def EscapeWildcards(string):
return string.replace("%", r"\%").replace("_", r"\_")


def EscapeBackslashes(string):
"""Escapes backslash characters for strings intended to be used with `LIKE`.
Backslashes work in mysterious ways: sometimes they do need to be escaped,
sometimes this is being done automatically when passing values. Combined with
unclear rules of `LIKE`, this can be very confusing.
https://what.thedailywtf.com/topic/13989/mysql-backslash-escaping
Args:
string: A string to escape.
Returns:
An escaped string.
"""
precondition.AssertType(string, Text)
return string.replace("\\", "\\\\")


def ClientIdFromGrrMessage(m):
if m.queue:
return m.queue.Split()[0]
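
A short, hand-worked illustration of how the two escape helpers compose; mysql_paths.py below applies them in exactly this order. The definitions are restated from the diff above so the snippet runs on its own:

def EscapeBackslashes(string):
  return string.replace("\\", "\\\\")

def EscapeWildcards(string):
  return string.replace("%", r"\%").replace("_", r"\_")

path = "\\foo\\bar_baz"  # one backslash before "foo" and before "bar"
escaped = EscapeWildcards(EscapeBackslashes(path))
# Backslashes are doubled first, then the LIKE metacharacter "_" is escaped.
assert escaped == "\\\\foo\\\\bar\\_baz"

The order matters: escaping wildcards first would let EscapeBackslashes double the backslash that EscapeWildcards just inserted, breaking the escape.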
10 changes: 10 additions & 0 deletions grr/server/grr_response_server/databases/mem_blobs.py
@@ -4,6 +4,8 @@
from __future__ import division
from __future__ import unicode_literals

import threading


from future.utils import itervalues

@@ -66,3 +68,11 @@ def ReadHashBlobReferences(self, hashes):
result[hash_id] = None

return result


class InMemoryBlobStore(InMemoryDBBlobsMixin):

def __init__(self):
self.blobs = {}
self.blob_refs_by_hashes = {}
self.lock = threading.RLock()
5 changes: 3 additions & 2 deletions grr/server/grr_response_server/databases/mysql_paths.py
@@ -327,10 +327,11 @@ def ListDescendentPathInfos(self,
query = ""

path = mysql_utils.ComponentsToPath(components)
escaped_path = db_utils.EscapeWildcards(db_utils.EscapeBackslashes(path))
values = {
"client_id": db_utils.ClientIDToInt(client_id),
"path_type": int(path_type),
"path": db_utils.EscapeWildcards(path),
"path": escaped_path,
}

query += """
@@ -390,7 +391,7 @@ def ListDescendentPathInfos(self,
query += """
WHERE p.client_id = %(client_id)s
AND p.path_type = %(path_type)s
AND path LIKE concat(%(path)s, '/%%')
AND path LIKE CONCAT(%(path)s, '/%%')
"""

if max_depth is not None:
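
A hand-worked sketch of the value that reaches the LIKE clause above, assuming MySQL's default (backslash-escaping) SQL mode and assuming, for illustration, that components are joined with "/" separators in the stored path column:

components = ("\\", "foo\\bar")  # as in the new backslash test above
path = "/" + "/".join(components)  # the stored form: /\/foo\bar
escaped = path.replace("\\", "\\\\")  # EscapeBackslashes: /\\/foo\\bar
# EscapeWildcards leaves this value unchanged (no "%" or "_" present).
assert escaped == "/\\\\/foo\\\\bar"
# MySQL reads "\\" in a LIKE pattern as one literal backslash, so
# CONCAT(%(path)s, '/%') matches the children of /\/foo\bar as intended.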