Skip to content

Commit 3c7489c

Browse files
authored
Merge pull request #381 from informatics-lab/fix-datetime-support
Fix datetime support
2 parents 17f330a + 9b9cee6 commit 3c7489c

File tree

6 files changed

+82
-18
lines changed

6 files changed

+82
-18
lines changed

forest/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
.. automodule:: forest.presets
2424
2525
"""
26-
__version__ = '0.17.3'
26+
__version__ = '0.17.4'
2727

2828
from .config import *
2929
from . import (

forest/db/database.py

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import numpy as np
99
import pandas as pd
1010
from .connection import Connection
11+
from forest import mark
1112

1213

1314
__all__ = [
@@ -373,9 +374,9 @@ def insert_pressure(self, path, variable, pressure, i):
373374
(SELECT id FROM pressure WHERE value=:pressure AND i=:i))
374375
""", dict(path=path, variable=variable, pressure=pressure, i=i))
375376

377+
@mark.sql_sanitize_time("initial_time")
376378
def valid_times(self, pattern, variable, initial_time):
377379
"""Valid times associated with search criteria"""
378-
initial_time = self.sanitize_time(initial_time)
379380
query = self.valid_times_query(pattern, variable, initial_time)
380381
self.cursor.execute(query, dict(
381382
variable=variable,
@@ -417,22 +418,9 @@ def valid_times_query(pattern, variable, initial_time):
417418
variable=variable,
418419
pattern=pattern)
419420

420-
@staticmethod
421-
def sanitize_time(value):
422-
"""Query-compatible equivalent of value"""
423-
fmt = "%Y-%m-%d %H:%M:%S"
424-
if value is None:
425-
return value
426-
elif isinstance(value, str):
427-
return value
428-
elif isinstance(value, np.datetime64):
429-
return pd.to_datetime(str(value)).strftime(fmt)
430-
else:
431-
return value.strftime(fmt)
432-
421+
@mark.sql_sanitize_time("initial_time")
433422
def pressures(self, pattern=None, variable=None, initial_time=None):
434423
"""Select pressures from database"""
435-
initial_time = self.sanitize_time(initial_time)
436424
query = self.pressures_query(pattern, variable, initial_time)
437425
self.cursor.execute(query, dict(
438426
variable=variable,

forest/db/locate.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import numpy as np
44
from .connection import Connection
55
from forest.exceptions import SearchFail
6+
from forest import mark
67

78

89
__all__ = [
@@ -17,6 +18,7 @@ def __init__(self, connection, directory=None):
1718
self.connection = connection
1819
self.cursor = self.connection.cursor()
1920

21+
@mark.sql_sanitize_time("initial_time", "valid_time")
2022
def locate(
2123
self,
2224
pattern,
@@ -66,6 +68,7 @@ def locate(
6668
return path, (ti, pi)
6769
raise SearchFail("Could not locate: {}".format(pattern))
6870

71+
@mark.sql_sanitize_time("initial_time", "valid_time")
6972
@lru_cache()
7073
def file_names(self, pattern, variable, initial_time, valid_time):
7174
self.cursor.execute("""

forest/mark.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
"""Decorators to mark classes and functions"""
2+
import inspect
23
from unittest.mock import Mock
34
from contextlib import contextmanager
45
from functools import wraps
56
from forest.observe import Observable
7+
import pandas as pd
8+
import numpy as np
69

710

811
def component(cls):
@@ -29,3 +32,48 @@ def disable(obj, method_name):
2932
setattr(obj, method_name, Mock())
3033
yield
3134
setattr(obj, method_name, method)
35+
36+
37+
def sql_sanitize_time(*labels):
38+
"""Decorator to protect SQL statements from unsupported datetime types
39+
40+
>>> @sql_sanitize_time("b", "c")
41+
... def method(self, a, b, c=None, d=False):
42+
... # b and c will be converted to a str compatible with SQL queries
43+
... pass
44+
45+
"""
46+
def outer(f):
47+
parameters = inspect.signature(f).parameters
48+
49+
# Get positional index
50+
index = {}
51+
for i, name in enumerate(parameters):
52+
if name in labels:
53+
index[name] = i
54+
55+
def inner(*args, **kwargs):
56+
args = list(args)
57+
for label in labels:
58+
if label in kwargs:
59+
kwargs[label] = sanitize_time(kwargs[label])
60+
else:
61+
i = index[label]
62+
if i < len(args):
63+
args[i] = sanitize_time(args[i])
64+
return f(*args, **kwargs)
65+
return inner
66+
return outer
67+
68+
69+
def sanitize_time(value):
70+
"""Query-compatible equivalent of value"""
71+
fmt = "%Y-%m-%d %H:%M:%S"
72+
if value is None:
73+
return value
74+
elif isinstance(value, str):
75+
return value
76+
elif isinstance(value, np.datetime64):
77+
return pd.to_datetime(str(value)).strftime(fmt)
78+
else:
79+
return value.strftime(fmt)

test/test_db_database.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import re
77

88
import forest.db.database as database
9+
import forest.mark
910

1011

1112
def _create_db():
@@ -203,5 +204,5 @@ def test_Database_valid_times_given_datetime_like_objects(initial_time):
203204
pytest.param(cftime.DatetimeGregorian(2020, 1, 1), id="cftime"),
204205
pytest.param(np.datetime64("2020-01-01", "ns"), id="np.datetime64"),
205206
])
206-
def test_Database_sanitize_datetime_like_objects(time):
207-
assert database.Database.sanitize_time(time) == "2020-01-01 00:00:00"
207+
def test_sanitize_datetime_like_objects(time):
208+
assert forest.mark.sanitize_time(time) == "2020-01-01 00:00:00"

test/test_db_locate.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,34 @@
1+
import pytest
12
import unittest
3+
from unittest.mock import Mock, sentinel
24
import sqlite3
35
import datetime as dt
6+
import cftime
7+
import numpy as np
8+
import pandas as pd
49
from forest import db
510
from forest.exceptions import SearchFail
611

712

13+
@pytest.mark.parametrize("time", [
14+
pytest.param("2020-01-01 00:00:00", id="str"),
15+
pytest.param(dt.datetime(2020, 1, 1), id="datetime"),
16+
pytest.param(cftime.DatetimeGregorian(2020, 1, 1), id="cftime"),
17+
pytest.param(np.datetime64("2020-01-01 00:00:00", "ns"), id="numpy"),
18+
pytest.param(pd.Timestamp("2020-01-01 00:00:00"), id="pandas"),
19+
])
20+
def test_locator_file_names_supports_datetime_types(time):
21+
path = "file.nc"
22+
variable = "variable"
23+
database = db.Database.connect(":memory:")
24+
database.insert_file_name(path, "2020-01-01 00:00:00")
25+
database.insert_time(path, variable, "2020-01-01 00:00:00", 0)
26+
locator = db.Locator(database.connection)
27+
result = locator.file_names(path, variable, time, time)
28+
expect = [path]
29+
assert expect == result
30+
31+
832
class TestLocate(unittest.TestCase):
933
def setUp(self):
1034
self.database = db.Database.connect(":memory:")

0 commit comments

Comments
 (0)