mirror of
https://github.com/emadeldeen24/TSLANet.git
synced 2025-02-23 10:59:04 +08:00
295 lines
11 KiB
Python
295 lines
11 KiB
Python
# From: gluonts/src/gluonts/time_feature/_base.py
|
|
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License").
|
|
# You may not use this file except in compliance with the License.
|
|
# A copy of the License is located at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# or in the "license" file accompanying this file. This file is distributed
|
|
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
|
|
# express or implied. See the License for the specific language governing
|
|
# permissions and limitations under the License.
|
|
|
|
from typing import List
|
|
|
|
import numpy as np
|
|
import pandas as pd
|
|
from pandas.tseries import offsets
|
|
from pandas.tseries.frequencies import to_offset
|
|
from distutils.util import strtobool
|
|
from datetime import datetime
|
|
|
|
class TimeFeature:
    """Base class for callables that map a datetime index to one feature.

    Subclasses override ``__call__``; the base implementation computes
    nothing and returns ``None``.
    """

    def __init__(self):
        """Create the feature object; there is no state to initialise."""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        """Placeholder to be overridden by subclasses; returns ``None``."""
        return None

    def __repr__(self):
        """Render as the class name followed by ``()``."""
        return f"{self.__class__.__name__}()"
|
|
|
|
|
|
class SecondOfMinute(TimeFeature):
    """Second of minute encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        """Scale ``index.second`` (0..59) linearly onto [-0.5, 0.5]."""
        return index.second / 59.0 - 0.5
|
|
|
|
|
|
class MinuteOfHour(TimeFeature):
    """Minute of hour encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        """Scale ``index.minute`` (0..59) linearly onto [-0.5, 0.5]."""
        unit_scaled = index.minute / 59.0
        return unit_scaled - 0.5
|
|
|
|
|
|
class HourOfDay(TimeFeature):
    """Hour of day encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        """Scale ``index.hour`` (0..23) linearly onto [-0.5, 0.5]."""
        unit_scaled = index.hour / 23.0
        return unit_scaled - 0.5
|
|
|
|
|
|
class DayOfWeek(TimeFeature):
    """Day of week encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        """Scale ``index.dayofweek`` (0..6) linearly onto [-0.5, 0.5]."""
        return index.dayofweek / 6.0 - 0.5
|
|
|
|
|
|
class DayOfMonth(TimeFeature):
    """Day of month encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        """Shift ``index.day`` to start at zero, then scale by 30 onto [-0.5, 0.5]."""
        zero_based_day = index.day - 1
        return zero_based_day / 30.0 - 0.5
|
|
|
|
|
|
class DayOfYear(TimeFeature):
    """Day of year encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        """Shift ``index.dayofyear`` to start at zero, then scale by 365.

        Day 366 of a leap year therefore maps to exactly 0.5.
        """
        zero_based_day = index.dayofyear - 1
        return zero_based_day / 365.0 - 0.5
|
|
|
|
|
|
class MonthOfYear(TimeFeature):
    """Month of year encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        """Shift ``index.month`` to start at zero, then scale by 11 onto [-0.5, 0.5]."""
        zero_based_month = index.month - 1
        return zero_based_month / 11.0 - 0.5
|
|
|
|
|
|
class WeekOfYear(TimeFeature):
    """Week of year encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        """Scale the ISO week number (``index.isocalendar().week``) by 52.

        The week number is shifted to start at zero first, so week 53 maps
        to exactly 0.5.
        """
        iso_week = index.isocalendar().week
        zero_based_week = iso_week - 1
        return zero_based_week / 52.0 - 0.5
|
|
|
|
|
|
def time_features_from_frequency_str(freq_str: str) -> List[TimeFeature]:
    """
    Returns a list of time features that will be appropriate for the given frequency string.

    Parameters
    ----------
    freq_str
        Frequency string of the form [multiple][granularity] such as "12H", "5min", "1D" etc.

    Returns
    -------
    List[TimeFeature]
        One freshly constructed feature object per class registered for the
        offset type that ``freq_str`` resolves to. The list is empty for
        yearly offsets.

    Raises
    ------
    RuntimeError
        If the parsed offset is not an instance of any registered offset type.
    """

    # Offset type -> feature classes, finest-grained component first.
    features_by_offsets = {
        offsets.YearEnd: [],
        offsets.QuarterEnd: [MonthOfYear],
        offsets.MonthEnd: [MonthOfYear],
        offsets.Week: [DayOfMonth, WeekOfYear],
        offsets.Day: [DayOfWeek, DayOfMonth, DayOfYear],
        offsets.BusinessDay: [DayOfWeek, DayOfMonth, DayOfYear],
        offsets.Hour: [HourOfDay, DayOfWeek, DayOfMonth, DayOfYear],
        offsets.Minute: [
            MinuteOfHour,
            HourOfDay,
            DayOfWeek,
            DayOfMonth,
            DayOfYear,
        ],
        offsets.Second: [
            SecondOfMinute,
            MinuteOfHour,
            HourOfDay,
            DayOfWeek,
            DayOfMonth,
            DayOfYear,
        ],
    }

    offset = to_offset(freq_str)

    # The first registered type the offset is an instance of wins.
    for offset_type, feature_classes in features_by_offsets.items():
        if isinstance(offset, offset_type):
            return [cls() for cls in feature_classes]

    # Quarterly is listed here because QuarterEnd is handled by the table above.
    supported_freq_msg = f"""
    Unsupported frequency {freq_str}
    The following frequencies are supported:
        Y   - yearly
            alias: A
        Q   - quarterly
        M   - monthly
        W   - weekly
        D   - daily
        B   - business days
        H   - hourly
        T   - minutely
            alias: min
        S   - secondly
    """
    raise RuntimeError(supported_freq_msg)
|
|
|
|
|
|
def time_features(dates, freq='h'):
    """Stack every time feature for ``freq`` into one 2-D array.

    Parameters
    ----------
    dates
        Datetime index handed unchanged to each feature callable.
    freq
        Frequency string understood by ``time_features_from_frequency_str``.

    Returns
    -------
    np.ndarray
        One row per feature, one column per entry of ``dates``. When the
        frequency maps to no features (yearly), an empty array of shape
        ``(0, len(dates))`` is returned.
    """
    feature_rows = [feat(dates) for feat in time_features_from_frequency_str(freq)]
    if not feature_rows:
        # np.vstack raises ValueError on an empty sequence, so build the
        # empty result explicitly instead of crashing.
        return np.empty((0, len(dates)))
    return np.vstack(feature_rows)
|
|
|
|
|
|
# Truth-value spellings accepted in @missing / @equallength headers. These
# mirror distutils.util.strtobool, which was removed in Python 3.12.
_TSF_TRUE_STRINGS = frozenset(("y", "yes", "t", "true", "on", "1"))
_TSF_FALSE_STRINGS = frozenset(("n", "no", "f", "false", "off", "0"))


def _tsf_parse_bool(text):
    """Convert a truth-value string to bool; raise ValueError if unrecognised."""
    lowered = text.lower()
    if lowered in _TSF_TRUE_STRINGS:
        return True
    if lowered in _TSF_FALSE_STRINGS:
        return False
    raise ValueError(f"invalid truth value {text!r}")


def _tsf_parse_attribute(raw_value, col_type):
    """Convert one attribute field according to its declared column type."""
    if col_type == "numeric":
        return int(raw_value)
    if col_type == "string":
        return str(raw_value)
    if col_type == "date":
        # Dates use '-' between time fields because ':' separates the fields
        # of a data line.
        return datetime.strptime(raw_value, "%Y-%m-%d %H-%M-%S")
    # Currently, the code supports only numeric, string and date types. Extend this as required.
    raise Exception("Invalid attribute type.")


def _tsf_parse_series(raw_series, replace_missing_vals_with):
    """Parse the comma separated value field of one data line into a list."""
    if not raw_series.strip():
        # str.split never returns an empty list, so an empty field has to be
        # detected on the raw text.
        raise Exception(
            "A given series should contains a set of comma separated numeric values. At least one numeric value should be there in a series. Missing values should be indicated with ? symbol"
        )

    tokens = [token.strip() for token in raw_series.split(",")]

    # Decide "all missing" from the '?' markers themselves, not from the
    # replacement value: a numeric replacement such as 0.0 would otherwise
    # reject a genuine all-zero series.
    if all(token == "?" for token in tokens):
        raise Exception(
            "All series values are missing. A given series should contains a set of comma separated numeric values. At least one numeric value should be there in a series."
        )

    return [
        replace_missing_vals_with if token == "?" else float(token)
        for token in tokens
    ]


def convert_tsf_to_dataframe(
    full_file_path_and_name,
    replace_missing_vals_with="NaN",
    value_column_name="series_value",
):
    """Load a ``.tsf`` time-series file into a pandas DataFrame.

    The file is read as cp1252 text. Blank lines and lines starting with
    ``#`` are ignored. Lines starting with ``@`` are meta-data, and every
    other line after the ``@data`` tag is one series whose fields are
    separated by ``:``, with the comma separated values last.

    Parameters
    ----------
    full_file_path_and_name
        Path of the file to read.
    replace_missing_vals_with
        Value stored in place of every ``?`` entry of a series.
    value_column_name
        Name of the DataFrame column that holds the series values.

    Returns
    -------
    tuple
        ``(loaded_data, frequency, forecast_horizon, contain_missing_values,
        contain_equal_length)``. ``loaded_data`` has one row per series, one
        column per ``@attribute`` and a final ``value_column_name`` column.
        The other entries come from the ``@frequency``, ``@horizon``,
        ``@missing`` and ``@equallength`` headers and stay ``None`` when the
        header is absent.

    Raises
    ------
    Exception
        If the file is empty or its meta-data or data section is malformed.
    ValueError
        If a truth value, integer, date or series value cannot be parsed.
    """
    col_names = []
    col_types = []
    all_data = {}
    all_series = []
    line_count = 0
    frequency = None
    forecast_horizon = None
    contain_missing_values = None
    contain_equal_length = None
    found_data_tag = False
    found_data_section = False

    with open(full_file_path_and_name, "r", encoding="cp1252") as file:
        for line in file:
            # Strip white space from start/end of line
            line = line.strip()
            if not line:
                continue
            line_count += 1

            if line.startswith("@"):  # Read meta-data
                if line.startswith("@data"):
                    if not col_names:
                        raise Exception(
                            "Missing attribute section. Attribute section must come before data."
                        )
                    found_data_tag = True
                    continue

                line_content = line.split(" ")

                if line.startswith("@attribute"):
                    # Attributes have both name and type
                    if len(line_content) != 3:
                        raise Exception("Invalid meta-data specification.")
                    col_names.append(line_content[1])
                    col_types.append(line_content[2])
                    continue

                # Other meta-data have only values
                if len(line_content) != 2:
                    raise Exception("Invalid meta-data specification.")

                if line.startswith("@frequency"):
                    frequency = line_content[1]
                elif line.startswith("@horizon"):
                    forecast_horizon = int(line_content[1])
                elif line.startswith("@missing"):
                    contain_missing_values = _tsf_parse_bool(line_content[1])
                elif line.startswith("@equallength"):
                    contain_equal_length = _tsf_parse_bool(line_content[1])
                continue

            if line.startswith("#"):
                continue

            # Anything else is a data line.
            if not col_names:
                raise Exception(
                    "Missing attribute section. Attribute section must come before data."
                )
            if not found_data_tag:
                raise Exception("Missing @data tag.")

            if not found_data_section:
                # First data line: create one output column per attribute.
                found_data_section = True
                for col in col_names:
                    all_data[col] = []

            full_info = line.split(":")
            if len(full_info) != (len(col_names) + 1):
                raise Exception("Missing attributes/values in series.")

            numeric_series = _tsf_parse_series(
                full_info[-1], replace_missing_vals_with
            )
            all_series.append(pd.Series(numeric_series).array)

            for col_name, col_type, raw_value in zip(col_names, col_types, full_info):
                all_data[col_name].append(_tsf_parse_attribute(raw_value, col_type))

    if line_count == 0:
        raise Exception("Empty file.")
    if not col_names:
        raise Exception("Missing attribute section.")
    if not found_data_section:
        raise Exception("Missing series information under data section.")

    all_data[value_column_name] = all_series
    loaded_data = pd.DataFrame(all_data)

    return (
        loaded_data,
        frequency,
        forecast_horizon,
        contain_missing_values,
        contain_equal_length,
    )
|