first commit

main
陆绍超 10 months ago
commit ed142fa8e7
  1. .idea/.gitignore (3 lines)
  2. .idea/AIShuiwu.iml (8 lines)
  3. .idea/dbnavigator.xml (402 lines)
  4. .idea/inspectionProfiles/Project_Default.xml (54 lines)
  5. .idea/inspectionProfiles/profiles_settings.xml (6 lines)
  6. .idea/misc.xml (4 lines)
  7. .idea/modules.xml (8 lines)
  8. .idea/vcs.xml (6 lines)
  9. TCN.py (124 lines)
  10. inference_4.py (568 lines)
  11. upload.html (22 lines)

.idea/.gitignore (3 lines, vendored)

@@ -0,0 +1,3 @@
# Default ignored files
/shelf/
/workspace.xml

.idea/AIShuiwu.iml (8 lines)

@@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="jdk" jdkName="Python 3.8 (pytorch_leanning)" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
</module>

.idea/dbnavigator.xml (402 lines)

@@ -0,0 +1,402 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="DBNavigator.Project.DatabaseFileManager">
<open-files />
</component>
<component name="DBNavigator.Project.Settings">
<connections />
<browser-settings>
<general>
<display-mode value="TABBED" />
<navigation-history-size value="100" />
<show-object-details value="false" />
</general>
<filters>
<object-type-filter>
<object-type name="SCHEMA" enabled="true" />
<object-type name="USER" enabled="true" />
<object-type name="ROLE" enabled="true" />
<object-type name="PRIVILEGE" enabled="true" />
<object-type name="CHARSET" enabled="true" />
<object-type name="TABLE" enabled="true" />
<object-type name="VIEW" enabled="true" />
<object-type name="MATERIALIZED_VIEW" enabled="true" />
<object-type name="NESTED_TABLE" enabled="true" />
<object-type name="COLUMN" enabled="true" />
<object-type name="INDEX" enabled="true" />
<object-type name="CONSTRAINT" enabled="true" />
<object-type name="DATASET_TRIGGER" enabled="true" />
<object-type name="DATABASE_TRIGGER" enabled="true" />
<object-type name="SYNONYM" enabled="true" />
<object-type name="SEQUENCE" enabled="true" />
<object-type name="PROCEDURE" enabled="true" />
<object-type name="FUNCTION" enabled="true" />
<object-type name="PACKAGE" enabled="true" />
<object-type name="TYPE" enabled="true" />
<object-type name="TYPE_ATTRIBUTE" enabled="true" />
<object-type name="ARGUMENT" enabled="true" />
<object-type name="DIMENSION" enabled="true" />
<object-type name="CLUSTER" enabled="true" />
<object-type name="DBLINK" enabled="true" />
</object-type-filter>
</filters>
<sorting>
<object-type name="COLUMN" sorting-type="NAME" />
<object-type name="FUNCTION" sorting-type="NAME" />
<object-type name="PROCEDURE" sorting-type="NAME" />
<object-type name="ARGUMENT" sorting-type="POSITION" />
<object-type name="TYPE ATTRIBUTE" sorting-type="POSITION" />
</sorting>
<default-editors>
<object-type name="VIEW" editor-type="SELECTION" />
<object-type name="PACKAGE" editor-type="SELECTION" />
<object-type name="TYPE" editor-type="SELECTION" />
</default-editors>
</browser-settings>
<navigation-settings>
<lookup-filters>
<lookup-objects>
<object-type name="SCHEMA" enabled="true" />
<object-type name="USER" enabled="false" />
<object-type name="ROLE" enabled="false" />
<object-type name="PRIVILEGE" enabled="false" />
<object-type name="CHARSET" enabled="false" />
<object-type name="TABLE" enabled="true" />
<object-type name="VIEW" enabled="true" />
<object-type name="MATERIALIZED VIEW" enabled="true" />
<object-type name="INDEX" enabled="true" />
<object-type name="CONSTRAINT" enabled="true" />
<object-type name="DATASET TRIGGER" enabled="true" />
<object-type name="DATABASE TRIGGER" enabled="true" />
<object-type name="SYNONYM" enabled="false" />
<object-type name="SEQUENCE" enabled="true" />
<object-type name="PROCEDURE" enabled="true" />
<object-type name="FUNCTION" enabled="true" />
<object-type name="PACKAGE" enabled="true" />
<object-type name="TYPE" enabled="true" />
<object-type name="DIMENSION" enabled="false" />
<object-type name="CLUSTER" enabled="false" />
<object-type name="DBLINK" enabled="true" />
</lookup-objects>
<force-database-load value="false" />
<prompt-connection-selection value="true" />
<prompt-schema-selection value="true" />
</lookup-filters>
</navigation-settings>
<dataset-grid-settings>
<general>
<enable-zooming value="true" />
<enable-column-tooltip value="true" />
</general>
<sorting>
<nulls-first value="true" />
<max-sorting-columns value="4" />
</sorting>
<audit-columns>
<column-names value="" />
<visible value="true" />
<editable value="false" />
</audit-columns>
</dataset-grid-settings>
<dataset-editor-settings>
<text-editor-popup>
<active value="false" />
<active-if-empty value="false" />
<data-length-threshold value="100" />
<popup-delay value="1000" />
</text-editor-popup>
<values-actions-popup>
<show-popup-button value="true" />
<element-count-threshold value="1000" />
<data-length-threshold value="250" />
</values-actions-popup>
<general>
<fetch-block-size value="100" />
<fetch-timeout value="30" />
<trim-whitespaces value="true" />
<convert-empty-strings-to-null value="true" />
<select-content-on-cell-edit value="true" />
<large-value-preview-active value="true" />
</general>
<filters>
<prompt-filter-dialog value="true" />
<default-filter-type value="BASIC" />
</filters>
<qualified-text-editor text-length-threshold="300">
<content-types>
<content-type name="Text" enabled="true" />
<content-type name="Properties" enabled="true" />
<content-type name="XML" enabled="true" />
<content-type name="DTD" enabled="true" />
<content-type name="HTML" enabled="true" />
<content-type name="XHTML" enabled="true" />
<content-type name="SQL" enabled="true" />
<content-type name="PL/SQL" enabled="true" />
<content-type name="JSON" enabled="true" />
<content-type name="JSON5" enabled="true" />
<content-type name="YAML" enabled="true" />
</content-types>
</qualified-text-editor>
<record-navigation>
<navigation-target value="VIEWER" />
</record-navigation>
</dataset-editor-settings>
<code-editor-settings>
<general>
<show-object-navigation-gutter value="false" />
<show-spec-declaration-navigation-gutter value="true" />
<enable-spellchecking value="true" />
<enable-reference-spellchecking value="false" />
</general>
<confirmations>
<save-changes value="false" />
<revert-changes value="true" />
<exit-on-changes value="ASK" />
</confirmations>
</code-editor-settings>
<code-completion-settings>
<filters>
<basic-filter>
<filter-element type="RESERVED_WORD" id="keyword" selected="true" />
<filter-element type="RESERVED_WORD" id="function" selected="true" />
<filter-element type="RESERVED_WORD" id="parameter" selected="true" />
<filter-element type="RESERVED_WORD" id="datatype" selected="true" />
<filter-element type="RESERVED_WORD" id="exception" selected="true" />
<filter-element type="OBJECT" id="schema" selected="true" />
<filter-element type="OBJECT" id="role" selected="true" />
<filter-element type="OBJECT" id="user" selected="true" />
<filter-element type="OBJECT" id="privilege" selected="true" />
<user-schema>
<filter-element type="OBJECT" id="table" selected="true" />
<filter-element type="OBJECT" id="view" selected="true" />
<filter-element type="OBJECT" id="materialized view" selected="true" />
<filter-element type="OBJECT" id="index" selected="true" />
<filter-element type="OBJECT" id="constraint" selected="true" />
<filter-element type="OBJECT" id="trigger" selected="true" />
<filter-element type="OBJECT" id="synonym" selected="false" />
<filter-element type="OBJECT" id="sequence" selected="true" />
<filter-element type="OBJECT" id="procedure" selected="true" />
<filter-element type="OBJECT" id="function" selected="true" />
<filter-element type="OBJECT" id="package" selected="true" />
<filter-element type="OBJECT" id="type" selected="true" />
<filter-element type="OBJECT" id="dimension" selected="true" />
<filter-element type="OBJECT" id="cluster" selected="true" />
<filter-element type="OBJECT" id="dblink" selected="true" />
</user-schema>
<public-schema>
<filter-element type="OBJECT" id="table" selected="false" />
<filter-element type="OBJECT" id="view" selected="false" />
<filter-element type="OBJECT" id="materialized view" selected="false" />
<filter-element type="OBJECT" id="index" selected="false" />
<filter-element type="OBJECT" id="constraint" selected="false" />
<filter-element type="OBJECT" id="trigger" selected="false" />
<filter-element type="OBJECT" id="synonym" selected="false" />
<filter-element type="OBJECT" id="sequence" selected="false" />
<filter-element type="OBJECT" id="procedure" selected="false" />
<filter-element type="OBJECT" id="function" selected="false" />
<filter-element type="OBJECT" id="package" selected="false" />
<filter-element type="OBJECT" id="type" selected="false" />
<filter-element type="OBJECT" id="dimension" selected="false" />
<filter-element type="OBJECT" id="cluster" selected="false" />
<filter-element type="OBJECT" id="dblink" selected="false" />
</public-schema>
<any-schema>
<filter-element type="OBJECT" id="table" selected="true" />
<filter-element type="OBJECT" id="view" selected="true" />
<filter-element type="OBJECT" id="materialized view" selected="true" />
<filter-element type="OBJECT" id="index" selected="true" />
<filter-element type="OBJECT" id="constraint" selected="true" />
<filter-element type="OBJECT" id="trigger" selected="true" />
<filter-element type="OBJECT" id="synonym" selected="true" />
<filter-element type="OBJECT" id="sequence" selected="true" />
<filter-element type="OBJECT" id="procedure" selected="true" />
<filter-element type="OBJECT" id="function" selected="true" />
<filter-element type="OBJECT" id="package" selected="true" />
<filter-element type="OBJECT" id="type" selected="true" />
<filter-element type="OBJECT" id="dimension" selected="true" />
<filter-element type="OBJECT" id="cluster" selected="true" />
<filter-element type="OBJECT" id="dblink" selected="true" />
</any-schema>
</basic-filter>
<extended-filter>
<filter-element type="RESERVED_WORD" id="keyword" selected="true" />
<filter-element type="RESERVED_WORD" id="function" selected="true" />
<filter-element type="RESERVED_WORD" id="parameter" selected="true" />
<filter-element type="RESERVED_WORD" id="datatype" selected="true" />
<filter-element type="RESERVED_WORD" id="exception" selected="true" />
<filter-element type="OBJECT" id="schema" selected="true" />
<filter-element type="OBJECT" id="user" selected="true" />
<filter-element type="OBJECT" id="role" selected="true" />
<filter-element type="OBJECT" id="privilege" selected="true" />
<user-schema>
<filter-element type="OBJECT" id="table" selected="true" />
<filter-element type="OBJECT" id="view" selected="true" />
<filter-element type="OBJECT" id="materialized view" selected="true" />
<filter-element type="OBJECT" id="index" selected="true" />
<filter-element type="OBJECT" id="constraint" selected="true" />
<filter-element type="OBJECT" id="trigger" selected="true" />
<filter-element type="OBJECT" id="synonym" selected="true" />
<filter-element type="OBJECT" id="sequence" selected="true" />
<filter-element type="OBJECT" id="procedure" selected="true" />
<filter-element type="OBJECT" id="function" selected="true" />
<filter-element type="OBJECT" id="package" selected="true" />
<filter-element type="OBJECT" id="type" selected="true" />
<filter-element type="OBJECT" id="dimension" selected="true" />
<filter-element type="OBJECT" id="cluster" selected="true" />
<filter-element type="OBJECT" id="dblink" selected="true" />
</user-schema>
<public-schema>
<filter-element type="OBJECT" id="table" selected="true" />
<filter-element type="OBJECT" id="view" selected="true" />
<filter-element type="OBJECT" id="materialized view" selected="true" />
<filter-element type="OBJECT" id="index" selected="true" />
<filter-element type="OBJECT" id="constraint" selected="true" />
<filter-element type="OBJECT" id="trigger" selected="true" />
<filter-element type="OBJECT" id="synonym" selected="true" />
<filter-element type="OBJECT" id="sequence" selected="true" />
<filter-element type="OBJECT" id="procedure" selected="true" />
<filter-element type="OBJECT" id="function" selected="true" />
<filter-element type="OBJECT" id="package" selected="true" />
<filter-element type="OBJECT" id="type" selected="true" />
<filter-element type="OBJECT" id="dimension" selected="true" />
<filter-element type="OBJECT" id="cluster" selected="true" />
<filter-element type="OBJECT" id="dblink" selected="true" />
</public-schema>
<any-schema>
<filter-element type="OBJECT" id="table" selected="true" />
<filter-element type="OBJECT" id="view" selected="true" />
<filter-element type="OBJECT" id="materialized view" selected="true" />
<filter-element type="OBJECT" id="index" selected="true" />
<filter-element type="OBJECT" id="constraint" selected="true" />
<filter-element type="OBJECT" id="trigger" selected="true" />
<filter-element type="OBJECT" id="synonym" selected="true" />
<filter-element type="OBJECT" id="sequence" selected="true" />
<filter-element type="OBJECT" id="procedure" selected="true" />
<filter-element type="OBJECT" id="function" selected="true" />
<filter-element type="OBJECT" id="package" selected="true" />
<filter-element type="OBJECT" id="type" selected="true" />
<filter-element type="OBJECT" id="dimension" selected="true" />
<filter-element type="OBJECT" id="cluster" selected="true" />
<filter-element type="OBJECT" id="dblink" selected="true" />
</any-schema>
</extended-filter>
</filters>
<sorting enabled="true">
<sorting-element type="RESERVED_WORD" id="keyword" />
<sorting-element type="RESERVED_WORD" id="datatype" />
<sorting-element type="OBJECT" id="column" />
<sorting-element type="OBJECT" id="table" />
<sorting-element type="OBJECT" id="view" />
<sorting-element type="OBJECT" id="materialized view" />
<sorting-element type="OBJECT" id="index" />
<sorting-element type="OBJECT" id="constraint" />
<sorting-element type="OBJECT" id="trigger" />
<sorting-element type="OBJECT" id="synonym" />
<sorting-element type="OBJECT" id="sequence" />
<sorting-element type="OBJECT" id="procedure" />
<sorting-element type="OBJECT" id="function" />
<sorting-element type="OBJECT" id="package" />
<sorting-element type="OBJECT" id="type" />
<sorting-element type="OBJECT" id="dimension" />
<sorting-element type="OBJECT" id="cluster" />
<sorting-element type="OBJECT" id="dblink" />
<sorting-element type="OBJECT" id="schema" />
<sorting-element type="OBJECT" id="role" />
<sorting-element type="OBJECT" id="user" />
<sorting-element type="RESERVED_WORD" id="function" />
<sorting-element type="RESERVED_WORD" id="parameter" />
</sorting>
<format>
<enforce-code-style-case value="true" />
</format>
</code-completion-settings>
<execution-engine-settings>
<statement-execution>
<fetch-block-size value="100" />
<execution-timeout value="20" />
<debug-execution-timeout value="600" />
<focus-result value="false" />
<prompt-execution value="false" />
</statement-execution>
<script-execution>
<command-line-interfaces />
<execution-timeout value="300" />
</script-execution>
<method-execution>
<execution-timeout value="30" />
<debug-execution-timeout value="600" />
<parameter-history-size value="10" />
</method-execution>
</execution-engine-settings>
<operation-settings>
<transactions>
<uncommitted-changes>
<on-project-close value="ASK" />
<on-disconnect value="ASK" />
<on-autocommit-toggle value="ASK" />
</uncommitted-changes>
<multiple-uncommitted-changes>
<on-commit value="ASK" />
<on-rollback value="ASK" />
</multiple-uncommitted-changes>
</transactions>
<session-browser>
<disconnect-session value="ASK" />
<kill-session value="ASK" />
<reload-on-filter-change value="false" />
</session-browser>
<compiler>
<compile-type value="KEEP" />
<compile-dependencies value="ASK" />
<always-show-controls value="false" />
</compiler>
<debugger>
<debugger-type value="JDBC" />
</debugger>
</operation-settings>
<ddl-file-settings>
<extensions>
<mapping file-type-id="VIEW" extensions="vw" />
<mapping file-type-id="TRIGGER" extensions="trg" />
<mapping file-type-id="PROCEDURE" extensions="prc" />
<mapping file-type-id="FUNCTION" extensions="fnc" />
<mapping file-type-id="PACKAGE" extensions="pkg" />
<mapping file-type-id="PACKAGE_SPEC" extensions="pks" />
<mapping file-type-id="PACKAGE_BODY" extensions="pkb" />
<mapping file-type-id="TYPE" extensions="tpe" />
<mapping file-type-id="TYPE_SPEC" extensions="tps" />
<mapping file-type-id="TYPE_BODY" extensions="tpb" />
</extensions>
<general>
<lookup-ddl-files value="true" />
<create-ddl-files value="false" />
<synchronize-ddl-files value="true" />
<use-qualified-names value="false" />
<make-scripts-rerunnable value="true" />
</general>
</ddl-file-settings>
<general-settings>
<regional-settings>
<date-format value="MEDIUM" />
<number-format value="UNGROUPED" />
<locale value="SYSTEM_DEFAULT" />
<use-custom-formats value="false" />
</regional-settings>
<environment>
<environment-types>
<environment-type id="development" name="Development" description="Development environment" color="-2430209/-12296320" readonly-code="false" readonly-data="false" />
<environment-type id="integration" name="Integration" description="Integration environment" color="-2621494/-12163514" readonly-code="true" readonly-data="false" />
<environment-type id="production" name="Production" description="Productive environment" color="-11574/-10271420" readonly-code="true" readonly-data="true" />
<environment-type id="other" name="Other" description="" color="-1576/-10724543" readonly-code="false" readonly-data="false" />
</environment-types>
<visibility-settings>
<connection-tabs value="true" />
<dialog-headers value="true" />
<object-editor-tabs value="true" />
<script-editor-tabs value="false" />
<execution-result-tabs value="true" />
</visibility-settings>
</environment>
</general-settings>
</component>
</project>

.idea/inspectionProfiles/Project_Default.xml (54 lines)

@@ -0,0 +1,54 @@
<component name="InspectionProjectProfileManager">
<profile version="1.0">
<option name="myName" value="Project Default" />
<inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
<option name="ignoredPackages">
<value>
<list size="33">
<item index="0" class="java.lang.String" itemvalue="PyAutoGUI" />
<item index="1" class="java.lang.String" itemvalue="rlcard" />
<item index="2" class="java.lang.String" itemvalue="PyQt5" />
<item index="3" class="java.lang.String" itemvalue="torch" />
<item index="4" class="java.lang.String" itemvalue="pandas" />
<item index="5" class="java.lang.String" itemvalue="matplotlib" />
<item index="6" class="java.lang.String" itemvalue="numpy" />
<item index="7" class="java.lang.String" itemvalue="einops" />
<item index="8" class="java.lang.String" itemvalue="cpm_kernels" />
<item index="9" class="java.lang.String" itemvalue="sentencepiece" />
<item index="10" class="java.lang.String" itemvalue="xformers" />
<item index="11" class="java.lang.String" itemvalue="streamlit" />
<item index="12" class="java.lang.String" itemvalue="transformers_stream_generator" />
<item index="13" class="java.lang.String" itemvalue="accelerate" />
<item index="14" class="java.lang.String" itemvalue="bitsandbytes" />
<item index="15" class="java.lang.String" itemvalue="yacs" />
<item index="16" class="java.lang.String" itemvalue="tqdm" />
<item index="17" class="java.lang.String" itemvalue="mmdet" />
<item index="18" class="java.lang.String" itemvalue="termcolor" />
<item index="19" class="java.lang.String" itemvalue="mmsegmentation" />
<item index="20" class="java.lang.String" itemvalue="timm" />
<item index="21" class="java.lang.String" itemvalue="fvcore" />
<item index="22" class="java.lang.String" itemvalue="pynvml" />
<item index="23" class="java.lang.String" itemvalue="pytz" />
<item index="24" class="java.lang.String" itemvalue="mmcv-full" />
<item index="25" class="java.lang.String" itemvalue="jinja2" />
<item index="26" class="java.lang.String" itemvalue="chardet" />
<item index="27" class="java.lang.String" itemvalue="bs4" />
<item index="28" class="java.lang.String" itemvalue="tornado" />
<item index="29" class="java.lang.String" itemvalue="social-auth-storage-sqlalchemy" />
<item index="30" class="java.lang.String" itemvalue="social-auth-app-tornado" />
<item index="31" class="java.lang.String" itemvalue="social-auth-core" />
<item index="32" class="java.lang.String" itemvalue="scikit_learn" />
</list>
</value>
</option>
</inspection_tool>
<inspection_tool class="PyPep8NamingInspection" enabled="true" level="WEAK WARNING" enabled_by_default="true">
<option name="ignoredErrors">
<list>
<option value="N801" />
<option value="N806" />
</list>
</option>
</inspection_tool>
</profile>
</component>

.idea/inspectionProfiles/profiles_settings.xml (6 lines)

@@ -0,0 +1,6 @@
<component name="InspectionProjectProfileManager">
<settings>
<option name="USE_PROJECT_PROFILE" value="false" />
<version value="1.0" />
</settings>
</component>

.idea/misc.xml (4 lines)

@@ -0,0 +1,4 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.8 (pytorch_leanning)" project-jdk-type="Python SDK" />
</project>

.idea/modules.xml (8 lines)

@@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/AIShuiwu.iml" filepath="$PROJECT_DIR$/.idea/AIShuiwu.iml" />
</modules>
</component>
</project>

.idea/vcs.xml (6 lines)

@@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="VcsDirectoryMappings">
<mapping directory="$PROJECT_DIR$" vcs="Git" />
</component>
</project>

TCN.py (124 lines)

@@ -0,0 +1,124 @@
"""
Author陆绍超
Project name:swDLiner_3
Created on 2024/05/10 上午11:45
"""
import torch
import torch.nn as nn
from torch.nn.utils import weight_norm
class Chomp1d(nn.Module):
def __init__(self, chomp_size):
super(Chomp1d, self).__init__()
self.chomp_size = chomp_size
def forward(self, x):
"""
其实这就是一个裁剪的模块裁剪多出来的padding
"""
return x[:, :, :-self.chomp_size].contiguous()
class TemporalBlock(nn.Module):
def __init__(self, n_inputs, n_outputs, kernel_size, stride, dilation, padding, dropout=0.2):
"""
相当于一个Residual block
:param n_inputs: int, 输入通道数
:param n_outputs: int, 输出通道数
:param kernel_size: int, 卷积核尺寸
:param stride: int, 步长一般为1
:param dilation: int, 膨胀系数
:param padding: int, 填充系数
:param dropout: float, dropout比率
"""
super(TemporalBlock, self).__init__()
self.conv1 = weight_norm(nn.Conv1d(n_inputs, n_outputs, kernel_size,
stride=stride, padding=padding, dilation=dilation))
# 经过conv1,输出的size其实是(Batch, input_channel, seq_len + padding)
self.chomp1 = Chomp1d(padding) # 裁剪掉多出来的padding部分,维持输出时间步为seq_len
self.relu1 = nn.ReLU()
self.dropout1 = nn.Dropout(dropout)
self.conv2 = weight_norm(nn.Conv1d(n_outputs, n_outputs, kernel_size,
stride=stride, padding=padding, dilation=dilation))
self.chomp2 = Chomp1d(padding) # 裁剪掉多出来的padding部分,维持输出时间步为seq_len
self.relu2 = nn.ReLU()
self.dropout2 = nn.Dropout(dropout)
self.net = nn.Sequential(self.conv1, self.chomp1, self.relu1, self.dropout1,
self.conv2, self.chomp2, self.relu2, self.dropout2)
self.downsample = nn.Conv1d(n_inputs, n_outputs, 1) if n_inputs != n_outputs else None
self.relu = nn.ReLU()
self.init_weights()
def init_weights(self):
"""
参数初始化
:return:
"""
self.conv1.weight.data.normal_(0, 0.01)
self.conv2.weight.data.normal_(0, 0.01)
if self.downsample is not None:
self.downsample.weight.data.normal_(0, 0.01)
def forward(self, x):
"""
:param x: size of (Batch, input_channel, seq_len)
:return:
"""
out = self.net(x)
res = x if self.downsample is None else self.downsample(x)
return self.relu(out + res)
class TemporalConvNet(nn.Module):
def __init__(self, seq_len, pred_len, num_inputs, num_channels, kernel_size=2, dropout=0.2):
"""
TCN目前paper给出的TCN结构很好的支持每个时刻为一个数的情况即sequence结构
对于每个时刻为一个向量这种一维结构勉强可以把向量拆成若干该时刻的输入通道
对于每个时刻为一个矩阵或更高维图像的情况就不太好办
:param num_inputs: int 输入通道数
:param num_channels: list每层的hidden_channel数例如[25,25,25,25]表示有4个隐层每层hidden_channel数为25
:param kernel_size: int, 卷积核尺寸
:param dropout: float, drop_out比率
"""
super(TemporalConvNet, self).__init__()
layers = []
num_levels = len(num_channels)
for i in range(num_levels):
dilation_size = 2 ** i # 膨胀系数:1,2,4,8……
in_channels = num_inputs if i == 0 else num_channels[i - 1] # 确定每一层的输入通道数
out_channels = num_channels[i] # 确定每一层的输出通道数
layers += [TemporalBlock(in_channels, out_channels, kernel_size, stride=1, dilation=dilation_size,
padding=(kernel_size - 1) * dilation_size, dropout=dropout)]
self.network = nn.Sequential(*layers)
self.mlp = nn.Linear(seq_len, pred_len)
def forward(self, x):
"""
输入x的结构不同于RNN一般RNN的size为(Batch, seq_len, channels)或者(seq_len, Batch, channels)
这里把seq_len放在channels后面把所有时间步的数据拼起来当做Conv1d的输入尺寸实现卷积跨时间步的操作
很巧妙的设计
:param x: size of (Batch, seq_len,input_channel)
:return: size of (Batch, seq_len, output_channel)
"""
x = x.permute(0, 2, 1)
x = self.network(x)
x = self.mlp(x)
x = x.permute(0, 2, 1)
return x
if __name__ == "__main__":
import argparse
x = torch.randn([2, 120, 25])
model_net = TemporalConvNet(seq_len=120, pred_len=60,num_inputs=32, num_channels=[25, 30, 35, 35, 30, 25])
pred = model_net(x)
print(pred)
print(pred.size())
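
A quick sanity check (editor's sketch, not part of the commit): each TemporalBlock applies two dilated causal convolutions, so with kernel size k and dilations 1, 2, 4, ... the receptive field of the stack is 1 + sum_i 2*(k-1)*2^i. For the six-level configuration above this comfortably covers the 120-step input window:

    def tcn_receptive_field(kernel_size, num_levels):
        # two dilated convs per TemporalBlock, dilation doubling each level
        return 1 + sum(2 * (kernel_size - 1) * 2 ** i for i in range(num_levels))

    print(tcn_receptive_field(kernel_size=2, num_levels=6))  # 127 >= seq_len of 120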

inference_4.py (568 lines)

@@ -0,0 +1,568 @@
"""
Author陆绍超
Project name:swDLiner
Created on 2024/05/07 下午1:20
"""
import os
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from TCN import TemporalConvNet
import json
import tornado.web
from tornado.escape import json_decode
from tornado.log import LogFormatter
import logging
from datetime import datetime
class NormalizedScaler:
def __init__(self):
self.min_value = 0.
self.max_value = 1.0
self.target_column_indices = None
def fit(self, data):
self.min_value = data.min(0)
self.max_value = data.max(0)
# 计算最小值和最大值
self.maxmin_zeros = ((self.max_value - self.min_value) <= 1e-2)
# print(self.maxmin_zeros)
def transform(self, data):
max_value = torch.from_numpy(self.max_value).type_as(data).to(data.device) if torch.is_tensor(
data) else self.max_value
min_value = torch.from_numpy(self.min_value).type_as(data).to(data.device) if torch.is_tensor(
data) else self.min_value
if any(self.maxmin_zeros):
normalized_data = torch.zeros_like(data) if torch.is_tensor(data) else np.zeros_like(data)
# 对每一列进行归一化,除非该列的最大值和最小值相等
for col in range(data.shape[1]):
if not self.maxmin_zeros[col]:
normalized_data[:, col] = (data[:, col] - min_value[col]) / (max_value[col] - min_value[col])
else:
normalized_data = (data - min_value) / (max_value - min_value)
return normalized_data
def y_transform(self, data):
max_value = torch.from_numpy(self.max_value[self.target_column_indices]).type_as(data).to(
data.device) if torch.is_tensor(
data) else self.max_value[self.target_column_indices]
min_value = torch.from_numpy(self.min_value[self.target_column_indices]).type_as(data).to(
data.device) if torch.is_tensor(
data) else self.min_value[self.target_column_indices]
maxmin_zeros = self.maxmin_zeros[self.target_column_indices]
if any(self.maxmin_zeros):
normalized_data = torch.zeros_like(data) if torch.is_tensor(data) else np.zeros_like(data)
# 对每一列进行归一化,除非该列的最大值和最小值相等
for col in range(data.shape[1]):
if not maxmin_zeros[col]:
normalized_data[:, col] = (data[:, col] - min_value[col]) / (max_value[col] - min_value[col])
else:
normalized_data = (data - min_value) / (max_value - min_value)
return normalized_data
def inverse_transform(self, data):
max_value = torch.from_numpy(self.max_value[self.target_column_indices]).type_as(data).to(
data.device) if torch.is_tensor(
data) else self.max_value[self.target_column_indices]
min_value = torch.from_numpy(self.min_value[self.target_column_indices]).type_as(data).to(
data.device) if torch.is_tensor(
data) else self.min_value[self.target_column_indices]
return (data * (max_value - min_value)) + min_value
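

# --- Editor's sketch (not part of the original commit): round-tripping a window
# through the scaler. Column 2 is constant to exercise the near-zero-range
# branch; with target_column_indices covering all columns,
# inverse_transform(transform(x)) recovers the original window:
#
#     window = np.array([[1.0, 10.0, 5.0],
#                        [2.0, 20.0, 5.0],
#                        [3.0, 30.0, 5.0]])
#     scaler = NormalizedScaler()
#     scaler.target_column_indices = [0, 1, 2]
#     scaler.fit(window)
#     print(np.allclose(scaler.inverse_transform(scaler.transform(window)), window))  # True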
class Dataset_GUISANLI_minute(Dataset):
    def __init__(self, size=None, target=None, column_order=None, scale=True):
        if target is None:
            self.target = ['Do', 'outCod', 'outNH3N', 'outPh', 'outTN', 'outTP']
        else:
            self.target = target
        if column_order is None:
            # column name list; the columns are arranged in this order
            self.column_order = ['Do', 'Do1', 'Do2', 'inCod', 'inFlow', 'inNH3N', 'inPh',
                                 'outCod', 'outFlow', 'outFlowNow', 'outNH3N', 'outPh',
                                 'outTN', 'outTP', 'yw_bz', 'yw_mc1', 'yw_mc2', 'yw_tj2']
        else:
            self.column_order = column_order
        if size is None:
            self.seq_len = 120
            self.pred_len = 60
        else:
            self.seq_len = size[0]
            self.pred_len = size[1]
        self.scale = scale
        self.scaler = NormalizedScaler()
        self.df_raw = None

    def read_data(self, df_raw):
        self.df_raw = df_raw
        '''
        df_raw.columns: ['date', ...(other features), target feature]
        '''
        if not all(column in df_raw.columns for column in self.column_order):
            print(f"DataFrame must contain columns: {self.column_order}")
        # reorder the columns according to the column name list
        df_data = df_raw[self.column_order]
        self.data_x = df_data
        self.data_y = df_data
        if self.target:
            self.data_y = self.data_y[self.target]
        if self.scale:
            # record the target columns' indices, in preparation for inverse normalization
            column_indices_1 = [self.data_x.columns.get_loc(col) for col in self.target]
            self.scaler.target_column_indices = column_indices_1

    def __getitem__(self, index):
        s_begin = index
        s_end = s_begin + self.seq_len
        r_begin = s_end
        r_end = r_begin + self.pred_len
        seq_x = self.data_x[s_begin:s_end]
        seq_y = self.data_y[r_begin:r_end]
        if self.scale:
            self.scaler.fit(seq_x.values)
            x_data = self.scaler.transform(seq_x.values)
            y_data = self.scaler.y_transform(seq_y.values)
        else:
            x_data, y_data = seq_x.values, seq_y.values  # scale=False: use raw values
        return torch.from_numpy(x_data).to(torch.float32), torch.from_numpy(y_data).to(torch.float32)

    def __len__(self):
        return len(self.data_x) - self.seq_len - self.pred_len + 1

    def inverse_transform(self, data):
        return self.scaler.inverse_transform(data)
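

# --- Editor's sketch (not part of the original commit): each sample is a
# sliding window of seq_len input steps followed by pred_len label steps, so
# with N rows the dataset yields N - seq_len - pred_len + 1 (input, label) pairs:
#
#     N, seq_len, pred_len = 300, 120, 60
#     print(N - seq_len - pred_len + 1)  # 121 windows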
class Pred_GUISANLI_minute():
    def __init__(self, size=None, target=None, column_order=None, scale=True, sn=None):
        if column_order is None:
            self.column_order = ['Do', 'Do1', 'Do2', 'inCod', 'inFlow', 'inNH3N', 'inPh',
                                 'outCod', 'outFlow', 'outFlowNow', 'outNH3N', 'outPh',
                                 'outTN', 'outTP', 'yw_bz', 'yw_mc1', 'yw_mc2', 'yw_tj2']
        else:
            self.column_order = column_order
        if target is None:
            self.target = ['Do', 'outCod', 'outNH3N', 'outPh', 'outTN', 'outTP']  # 6 targets
        else:
            self.target = target
        if size is None:
            self.seq_len = 120
            self.pred_len = 60
        else:
            self.seq_len = size[0]
            self.pred_len = size[1]
        self.scale = scale
        self.scaler = NormalizedScaler()
        self.sn = sn
        self.df_raw = None

    def get_df_raw(self, df_raw):
        self.df_raw = df_raw
        '''
        df_raw.columns: ['date', ...(other features), target feature]
        '''
        # every column in the column name list must be present
        if not all(column in self.df_raw.columns for column in self.column_order):
            print(f"DataFrame must contain columns: {self.column_order}")

    def __getitem__(self, index):
        self.data_x = self.df_raw[self.column_order]  # feature data for prediction
        self.data_date = self.df_raw['date']  # timestamp data
        if self.scale:
            # record the target columns' indices, in preparation for inverse normalization
            column_indices_1 = [self.data_x.columns.get_loc(col) for col in self.target]
            self.scaler.target_column_indices = column_indices_1
        # take the most recent window (offset by index), columns arranged per column_order
        s_begin = len(self.data_x) - self.seq_len - index
        s_end = s_begin + self.seq_len
        seq_date = self.data_date[s_begin:s_end]
        seq_x = self.data_x[s_begin:s_end]
        if self.scale:
            # debugging hooks: inspect the window (already cross-checked against test labels)
            # print('==start' * 20)
            # print(seq_x)
            # print('==end' * 20)
            self.scaler.fit(seq_x.values)
            x_data = self.scaler.transform(seq_x.values)
        return seq_date.values, torch.from_numpy(x_data).to(torch.float32)

    def __len__(self):
        if self.df_raw is None:
            return 0
        elif (len(self.df_raw) - self.seq_len + 1) < 0:
            return 0
        else:
            return len(self.df_raw) - self.seq_len + 1

    def inverse_transform(self, data):
        return self.scaler.inverse_transform(data)
def load_model(weights_path, num_inputs=32, num_outputs=6):
    predict_model = TemporalConvNet(seq_len=120,
                                    pred_len=60,
                                    num_inputs=num_inputs,
                                    num_channels=[64, 128, 256, 128, 64, 32, num_outputs])  # build the model
    if os.path.exists(weights_path):
        model_weights = torch.load(weights_path)  # read the weights file
        predict_model.load_state_dict(model_weights)  # load the weights into the model
    else:
        print("Model weights file does not exist")
    return predict_model


def config_init():
    # read the JSON from a file and convert it back to a dict
    # (currently hard-coded below; the file-based load is commented out)
    config_load = {
        '20210225GUISANLI': {'model': './Upload/GUIGWULI/TCN_weights_GUIGWULIm1.pth',
                             'data_loader': '20210225GUISANLI',
                             'SN': '20210225GUISANLI',
                             'target': ['Do', 'outCod', 'outNH3N', 'outPh', 'outTN', 'outTP'],
                             'columns': ['Do', 'Do1', 'Do2', 'inCod', 'inFlowNow', 'inNH3N', 'inPh', 'outCod',
                                         'outFlowNow', 'outNH3N', 'outPh', 'outTN', 'outTP', 'yw_bz', 'yw_mc1',
                                         'yw_mc2', 'yw_tj2'],
                             },
        '20210207GUIGWULI': {'model': './Upload/GUIGWULI/TCN_weights_GUIGWULIm1.pth',
                             'data_loader': '20210207GUIGWULI',
                             'SN': '20210207GUIGWULI',
                             'target': ['Do', 'outCod', 'outNH3N', 'outPh', 'outTN', 'outTP'],
                             'columns': ['Do', 'inCod', 'inFlowNow', 'inNH3N', 'inPh', 'outCod', 'outFlowNow',
                                         'outNH3N', 'outPh', 'outTN', 'outTP', 'yw_bz', 'yw_mc1', 'yw_mc2', 'yw_tj1'],
                             },
        '20210309ZHANGMUZ': {'model': './Upload/GUIGWULI/TCN_weights_GUIGWULIm1.pth',
                             'data_loader': '20210309ZHANGMUZ',
                             'SN': '20210309ZHANGMUZ',
                             'target': ['outCOD', 'outNH3N', 'outPH', 'outTN', 'outTP'],
                             'columns': ['inCOD', 'inFlowNow', 'inNH3N', 'inPH', 'outCOD', 'outFlowNow', 'outNH3N',
                                         'outPH', 'outTN', 'outTP', 'yw_bz', 'yw_mc1', 'yw_mc2', 'yw_mc3', 'yw_mc4',
                                         'yw_tj1', 'yw_tj2', 'yw_tj3', 'yw_tj4']
                             },
    }
    # with open('config', 'r', encoding='utf-8') as f:
    #     config_load = json.load(f)
    #     config_load = dict(config_load)
    configs = {}
    for key, val in config_load.items():
        config_item = {}
        for k, v in val.items():
            if k == 'model':
                config_item[k] = load_model(weights_path=v,
                                            num_inputs=len(val.get('columns', [])),
                                            num_outputs=len(val.get('target', [])))
            elif k == 'data_loader':
                config_item[k] = Pred_GUISANLI_minute(sn=v,
                                                      target=val.get('target', None),
                                                      column_order=val.get('columns', None))
            elif k in ('SN', 'target', 'columns'):
                config_item[k] = v
            else:
                raise ValueError("Invalid configuration key")
        configs[key] = config_item
    return configs
configs = config_init()


def pseudo_model_predict(model, pred_data_loader):
    # try to load prediction data from pred_data_loader
    if len(pred_data_loader) > 0:
        # pred_data_loader is indexable and yields at least one window
        date, predict_data = pred_data_loader[0]
    else:
        return {}
    try:
        # turn the window into a batch; in PyTorch every batch needs at least one sample
        predict_data = torch.unsqueeze(predict_data, 0)  # add a batch dimension at dim 0
        # make sure the model is in evaluation mode
        model.eval()
        # run inference
        predict_result = model(predict_data)
        # drop the batch dimension
        predict_result = torch.squeeze(predict_result, 0)
        # post-process: map the predictions back to the original scale
        predict_result = pred_data_loader.inverse_transform(predict_result)
        # make sure the prediction result is a numpy array
        predict_result = predict_result.detach().numpy()
        # build a time index starting one minute after the last input timestamp
        start_time = pd.Timestamp(date[-1])
        date = pd.date_range(start=start_time + pd.Timedelta(minutes=1), periods=len(predict_result), freq='T')
        # build a DataFrame with the time index as a column
        df = pd.DataFrame(date, columns=['date'])
        # output column names; kept in the same order as the configured targets
        target_headers = ['Do', 'outCod', 'outNH3N', 'outPh', 'outTN', 'outTP']
        # attach the predicted columns to the DataFrame
        df[target_headers] = predict_result
        print(df)
        # convert the DataFrame to a JSON string
        json_str = df.to_json(orient='records')
        print(json_str)
    except Exception as e:  # catch any exception that may occur
        # log the error message
        print(f"An error occurred: {e}")
        # return an empty dict in place of the JSON string
        json_str = {}
    return json_str
# model prediction endpoint
class PredictHandler(tornado.web.RequestHandler):
    def get(self, keyword):
        if keyword in configs.keys():
            json_str = pseudo_model_predict(configs[keyword]['model'], configs[keyword]['data_loader'])
            # build the response payload
            response = {"prediction": json_str}
            # set the response Content-Type to application/json
            self.set_header("Content-Type", "application/json")
            # return the result to the client
            self.write(json.dumps(response))
        else:
            self.write("Unknown keyword.")


# model upload endpoint
class UploadHandler(tornado.web.RequestHandler):
    def post(self):
        # read the form fields
        group_name = self.get_body_argument('groupName')
        model_file = self.request.files['modelFile'][0]
        csv_file = self.request.files['csvTable'][0]
        # create the directory for this group
        save_path = os.path.join('./Upload', group_name)
        if not os.path.exists(save_path):
            os.makedirs(save_path)
        # save the model file
        model_filename = model_file.filename
        model_path = os.path.join(save_path, model_filename)
        with open(model_path, 'wb') as f:
            f.write(model_file.body)
        # save the CSV file
        csv_filename = csv_file.filename
        csv_path = os.path.join(save_path, csv_filename)
        with open(csv_path, 'wb') as f:
            f.write(csv_file.body)
        self.write(f'Files for group "{group_name}" have been uploaded and saved successfully.')
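

# --- Usage sketch (editor's assumption, not part of the original commit): with
# the server listening on port 8886, both endpoints can be exercised from a shell:
#
#     # request a forecast for a configured station SN
#     curl http://localhost:8886/predict/20210225GUISANLI
#
#     # upload a model/CSV pair for a group (same fields as the form in upload.html)
#     curl -F groupName=GUIGWULI -F modelFile=@model.pth -F csvTable=@data.csv \
#          http://localhost:8886/upload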
async def train(data_set, predict_model, pth_save_name):
    print('Model training started')
    random_seed = 240510  # set a random seed for reproducibility
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)
    # wrap the dataset in a DataLoader (no train/validation split is done here)
    train_dataloader = DataLoader(
        data_set,
        batch_size=16,
        shuffle=True,
        num_workers=0,
        drop_last=False)
    loss_function = nn.MSELoss()  # MSE as the regression loss
    optimizer = torch.optim.Adam(predict_model.parameters(), lr=0.0001)  # Adam optimizer
    epochs = 4  # number of training epochs
    train_epoch_loss = []  # average loss of each training epoch
    for epoch in range(epochs):
        # train --------------------------------------------------------------------------------------------------
        predict_model.train()
        train_step_loss = []
        for step, data in enumerate(train_dataloader):
            sample, label = data
            optimizer.zero_grad()  # clear gradients; PyTorch accumulates them by default
            pre = predict_model(sample)
            loss = loss_function(pre, label)
            loss.backward()
            optimizer.step()
            train_step_loss.append(loss.item())
        train_average_loss = sum(train_step_loss) / len(train_step_loss)
        train_epoch_loss.append(train_average_loss)
        print(f"[epoch {epoch + 1}] train_epoch_loss = {train_average_loss:.4f}")
    torch.save(predict_model.state_dict(), pth_save_name)
    print('Model training finished')
    return predict_model
# ==========================================
# Fetch data on a schedule
# ==========================================
http_client = tornado.httpclient.AsyncHTTPClient()


async def generate_data():
    global http_client
    global configs
    try:
        for k1, v1 in configs.items():
            SN = v1['SN']
            # request headers
            headers = {
                'Authority': 'iot.gxghzh.com:8888',
                'Method': 'POST',
                'Path': '/exeCmd',
                'Scheme': 'https',
                'Accept': 'application/json, text/plain, */*',
                'Accept-Encoding': 'gzip, deflate, br, zstd',
                'Accept-Language': 'zh-CN,zh;q=0.9,en;q=0.8,en-GB;q=0.7,en-US;q=0.6',
                'Cmd': 'GetAllConfigs',
                'Content-Length': '2',
                'Content-Type': 'application/json;charset=UTF-8',
                'Origin': 'http://127.0.0.1:6810',
                'Priority': 'u=1, i',
                'Referer': 'http://127.0.0.1:6810/',
                'Sec-Ch-Ua': '"Chromium";v="124", "Microsoft Edge";v="124", "Not-A.Brand";v="99"',
                'Sec-Ch-Ua-Mobile': '?0',
                'Sec-Ch-Ua-Platform': '"Windows"',
                'Sec-Fetch-Dest': 'empty',
                'Sec-Fetch-Mode': 'cors',
                'Sec-Fetch-Site': 'cross-site',
                'Sn': SN,
                'Token': '45a73a59b3d23545',
                'Uid': '0',
                'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) '
                              'Chrome/124.0.0.0 '
                              'Safari/537.36 Edg/124.0.0.0 '
            }
            # build and send the POST request
            response = await http_client.fetch("https://iot.gxghzh.com:8888/exeCmd", method="POST", headers=headers,
                                               body=json.dumps({}))
            # check the response status code
            if response.code == 200:
                # parse the response data (assumed to be JSON)
                response_data = response.body.decode('utf-8')
                # parse the string into a JSON object
                response_data = json.loads(response_data)
                SwitchTerminals = response_data['Result']['SwitchTerminals']
                AnalogTerminals = response_data['Result']['AnalogTerminals']
                # take the current timestamp
                current_time = datetime.now()
                item_dict = {'date': current_time}  # attach the current time to the record
                item_key_list_1 = []
                item_key_list_2 = []
                for child in AnalogTerminals:
                    key = child['key']
                    value = child['value']
                    item_dict[key] = value
                    item_key_list_1.append(key)
                for child in SwitchTerminals:
                    key = child['key']
                    value = child['value']
                    item_dict[key] = value
                    item_key_list_2.append(key)
                if v1['data_loader'].df_raw is None:
                    # first run: create the DataFrame
                    # current date
                    current_date = datetime.now().strftime("%Y%m%d")
                    # build the file name
                    file_pattern = f"./{current_date}_{SN}.csv"
                    # check whether that file already exists in the current folder
                    file_exists = os.path.exists(file_pattern)
                    if file_exists:
                        v1['data_loader'].get_df_raw(pd.read_csv(file_pattern, parse_dates=True))
                    else:
                        item_key_list_1 = sorted(item_key_list_1)
                        item_key_list_2 = sorted(item_key_list_2)
                        sort_list = ['date'] + item_key_list_1 + item_key_list_2
                        print(sort_list, len(sort_list))
                        v1['data_loader'].get_df_raw(pd.DataFrame(columns=sort_list))
                # append the new row with concat
                v1['data_loader'].df_raw = pd.concat([v1['data_loader'].df_raw, pd.DataFrame([item_dict])],
                                                     ignore_index=True)
                # json_str_GUISANLI = pseudo_model_predict(v1['model'], v1['data_loader'])
                print(f'Request succeeded, status code: {response.code}')
            else:
                print(f'Request failed, status code: {response.code}')
            print("============ df snapshot (printed every minute) ====================")
            print(v1['data_loader'].df_raw.tail())  # show the df tail every minute
            print(f"shape: {v1['data_loader'].df_raw.shape}")
            print("===============================================")
            # every 10 rows, persist to file (to guard against data loss) and retrain
            if len(v1['data_loader'].df_raw) % 10 == 0:
                # current date
                current_date = datetime.now().strftime("%Y%m%d")
                # build the file name
                file_pattern = f"./{current_date}_{SN}.csv"
                data_set = Dataset_GUISANLI_minute(target=v1.get('target', None), column_order=v1.get('columns', None))
                data_set.read_data(v1['data_loader'].df_raw)
                predict_model = v1['model']
                configs[k1]['model'] = await train(data_set=data_set,
                                                   predict_model=predict_model,
                                                   pth_save_name=f"./{current_date}_{SN}_TCN.pth")
                v1['data_loader'].df_raw.to_csv(file_pattern, index=False)
                # df_new = v1['data_loader'].df_raw.iloc[-120:, :].copy()
                # del v1['data_loader'].df_raw
                # v1['data_loader'].df_raw = df_new
                # v1['data_loader'].df_raw.reset_index(drop=True, inplace=True)
    except tornado.httpclient.HTTPError as e:
        print("HTTP Error:", e)
    except Exception as e:
        print("Exception:", e)
    # note: the shared AsyncHTTPClient is deliberately left open; this coroutine
    # runs once a minute via PeriodicCallback and must keep reusing the client
# create the Tornado application
app = tornado.web.Application([
    (r"/predict/(\w+)", PredictHandler),
    (r"/upload", UploadHandler),
])

# logging configuration
# logger = logging.getLogger()
# logger.setLevel(logging.INFO)
#
# formatter = LogFormatter(
#     fmt='%(color)s[%(asctime)s] %(levelname)s - %(message)s%(end_color)s',
#     datefmt='%Y-%m-%d %H:%M:%S'
# )
#
# # set up the log file
# file_handler = logging.FileHandler("tornado.log")
# file_handler.setFormatter(formatter)
# logger.addHandler(file_handler)

if __name__ == "__main__":
    # start the server
    app.listen(8886)
    print("Server started on port 8886")
    # invoke generate_data every 60 seconds
    tornado.ioloop.PeriodicCallback(generate_data, 60000).start()
    tornado.ioloop.IOLoop.current().start()
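
# --- Run sketch (editor's assumption, not part of the original commit): with
# TCN.py alongside and the weight files from config_init in place, start the
# service with
#
#     python inference_4.py        # prints "Server started on port 8886"
#
# PeriodicCallback then calls generate_data once a minute, appending each IoT
# reading to the per-station DataFrame and retraining/persisting every 10 rows.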

upload.html (22 lines)

@@ -0,0 +1,22 @@
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <title>Model and CSV File Upload</title>
</head>
<body>
<h1>Upload a Model and a CSV File</h1>
<form action="/upload" method="post" enctype="multipart/form-data">
    <label for="groupName">Group name:</label>
    <input type="text" id="groupName" name="groupName" required><br><br>
    <label for="modelFile">Model file:</label>
    <input type="file" id="modelFile" name="modelFile" required><br><br>
    <label for="csvTable">CSV table:</label>
    <input type="file" id="csvTable" name="csvTable" accept=".csv" required><br><br>
    <input type="submit" value="Upload">
</form>
</body>
</html>