Skip to content

Commit

Permalink
syntax cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
leoglonz committed Nov 17, 2024
1 parent 965ce9a commit c14de9b
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 102 deletions.
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -163,4 +163,7 @@ cython_debug/


# VCS versioning
src/hydroDL2/_version.py
src/hydroDL2/_version.py

# Other
*sacsma*
63 changes: 32 additions & 31 deletions src/hydroDL2/models/hbv/hbv.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from hydroDL2.core.calc.uh_routing import UH_conv, UH_gamma



class HBV(torch.nn.Module):
"""Multi-component Pytorch HBV model.
Expand All @@ -13,7 +14,7 @@ class HBV(torch.nn.Module):
which runs the HBV-light hydrological model (Seibert, 2005).
"""
def __init__(self, config=None, device=None):
super(HBV, self).__init__()
super().__init__()
self.config = config
self.initialize = False
self.warm_up = 0
Expand All @@ -25,20 +26,20 @@ def __init__(self, config=None, device=None):
self.comprout = False
self.nearzero = 1e-5
self.nmul = 1
self.parameter_bounds = dict(
parBETA=[1.0, 6.0],
parFC=[50, 1000],
parK0=[0.05, 0.9],
parK1=[0.01, 0.5],
parK2=[0.001, 0.2],
parLP=[0.2, 1],
parPERC=[0, 10],
parUZL=[0, 100],
parTT=[-2.5, 2.5],
parCFMAX=[0.5, 10],
parCFR=[0, 0.1],
parCWH=[0, 0.2]
)
self.parameter_bounds = {
parBETA: [1.0, 6.0],
parFC: [50, 1000],
parK0: [0.05, 0.9],
parK1: [0.01, 0.5],
parK2: [0.001, 0.2],
parLP: [0.2, 1],
parPERC: [0, 10],
parUZL: [0, 100],
parTT: [-2.5, 2.5],
parCFMAX: [0.5, 10],
parCFR: [0, 0.1],
parCWH: [0, 0.2]
}
self.conv_routing_hydro_model_bound = [
[0, 2.9], # routing parameter a
[0, 6.5] # routing parameter b
Expand All @@ -57,12 +58,12 @@ def __init__(self, config=None, device=None):
self.nearzero = config['phy_model']['nearzero']
self.nmul = config['nmul']

if 'parBETAET' in config['phy_model']['dy_params']['HBV']:
if 'parBETAET' in self.dy_params :
self.parameter_bounds['parBETAET'] = [0.3, 5]

def forward(self, x, parameters, routing_parameters=None, muwts=None,
comprout=False):
"""Forward pass for HBV"""
"""Forward pass for HBV."""
# Initialization
if self.warm_up > 0:
with torch.no_grad():
Expand All @@ -87,20 +88,20 @@ def forward(self, x, parameters, routing_parameters=None, muwts=None,
)
else:
# Without warm-up, initialize state variables with zeros.
Ngrid = x.shape[1]
SNOWPACK = torch.zeros([Ngrid, self.nmul],
n_grid = x.shape[1]
SNOWPACK = torch.zeros([n_grid, self.nmul],
dtype=torch.float32,
device=self.device) + 0.001
MELTWATER = torch.zeros([Ngrid, self.nmul],
MELTWATER = torch.zeros([n_grid, self.nmul],
dtype=torch.float32,
device=self.device) + 0.001
SM = torch.zeros([Ngrid, self.nmul],
SM = torch.zeros([n_grid, self.nmul],
dtype=torch.float32,
device=self.device) + 0.001
SUZ = torch.zeros([Ngrid, self.nmul],
SUZ = torch.zeros([n_grid, self.nmul],
dtype=torch.float32,
device=self.device) + 0.001
SLZ = torch.zeros([Ngrid, self.nmul],
SLZ = torch.zeros([n_grid, self.nmul],
dtype=torch.float32,
device=self.device) + 0.001

Expand All @@ -122,10 +123,10 @@ def forward(self, x, parameters, routing_parameters=None, muwts=None,
Tm = T.unsqueeze(2).repeat(1, 1, self.nmul)
PETm = PET.unsqueeze(-1).repeat(1, 1, self.nmul)

Nstep, Ngrid = P.size()
n_steps, n_grid = P.size()

# Apply correction factor to precipitation
# P = parPCORR.repeat(Nstep, 1) * P
# P = parPCORR.repeat(n_steps, 1) * P

# Initialize time series of model variables in shape [time, basins, nmul].
Qsimmu = torch.zeros(Pm.size(), dtype=torch.float32, device=self.device) + 0.001
Expand Down Expand Up @@ -153,14 +154,14 @@ def forward(self, x, parameters, routing_parameters=None, muwts=None,
# as static in some basins.)
if len(self.dy_params) > 0:
params_dict_raw_dy = dict()
pmat = torch.ones([Ngrid, 1]) * self.dy_drop
pmat = torch.ones([n_grid, 1]) * self.dy_drop
for i, key in enumerate(self.dy_params):
drmask = torch.bernoulli(pmat).detach_().to(self.device)
dynPar = params_dict_raw[key]
staPar = params_dict_raw[key][self.static_idx, :, :].unsqueeze(0).repeat([dynPar.shape[0], 1, 1])
params_dict_raw_dy[key] = dynPar * (1 - drmask) + staPar * drmask

for t in range(Nstep):
for t in range(n_steps):
# Get dynamic parameter values per timestep.
for key in self.dy_params:
params_dict[key] = params_dict_raw_dy[key][self.warm_up + t, :, :]
Expand Down Expand Up @@ -249,7 +250,7 @@ def forward(self, x, parameters, routing_parameters=None, muwts=None,
# Routing for all components or just the average.
if comprout:
# All components; reshape to [time, gages * num models]
Qsim = Qsimmu.view(Nstep, Ngrid * self.nmul)
Qsim = Qsimmu.view(n_steps, n_grid * self.nmul)
else:
# Average, then do routing.
Qsim = Qsimavg
Expand All @@ -263,8 +264,8 @@ def forward(self, x, parameters, routing_parameters=None, muwts=None,
param=routing_parameters[:, 1],
bounds=self.conv_routing_hydro_model_bound[1]
)
rout_a = temp_a.repeat(Nstep, 1).unsqueeze(-1)
rout_b = temp_b.repeat(Nstep, 1).unsqueeze(-1)
rout_a = temp_a.repeat(n_steps, 1).unsqueeze(-1)
rout_b = temp_b.repeat(n_steps, 1).unsqueeze(-1)

UH = UH_gamma(rout_a, rout_b, lenF=15) # lenF: folter
rf = torch.unsqueeze(Qsim, -1).permute([1, 2, 0]) # [gages,vars,time]
Expand All @@ -281,7 +282,7 @@ def forward(self, x, parameters, routing_parameters=None, muwts=None,

if comprout:
# Qs is now shape [time, [gages*num models], vars]
Qstemp = Qsrout.view(Nstep, Ngrid, self.nmul)
Qstemp = Qsrout.view(n_steps, n_grid, self.nmul)
if muwts is None:
Qs = Qstemp.mean(-1, keepdim=True)
else:
Expand Down
68 changes: 33 additions & 35 deletions src/hydroDL2/models/hbv/hbv_1_1p.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from hydroDL2.core.calc.uh_routing import UH_conv, UH_gamma



class HBVCapillary(torch.nn.Module):
"""Multi-component Pytorch HBV1.1p model with capillary rise modification
and option to run without warmup.
Expand All @@ -14,7 +15,7 @@ class HBVCapillary(torch.nn.Module):
which runs the HBV-light hydrological model (Seibert, 2005).
"""
def __init__(self, config=None, device=None):
super(HBVCapillary, self).__init__()
super().__init__()
self.config = config
self.initialize = False
self.warm_up = 0
Expand All @@ -28,22 +29,22 @@ def __init__(self, config=None, device=None):
self.comprout = False
self.nearzero = 1e-5
self.nmul = 1
self.parameter_bounds = dict(
parBETA=[1.0, 6.0],
parFC=[50, 1000],
parK0=[0.05, 0.9],
parK1=[0.01, 0.5],
parK2=[0.001, 0.2],
parLP=[0.2, 1],
parPERC=[0, 10],
parUZL=[0, 100],
parTT=[-2.5, 2.5],
parCFMAX=[0.5, 10],
parCFR=[0, 0.1],
parCWH=[0, 0.2],
parBETAET=[0.3, 5],
parC=[0, 1]
)
self.parameter_bounds = {
parBETA: [1.0, 6.0],
parFC: [50, 1000],
parK0: [0.05, 0.9],
parK1: [0.01, 0.5],
parK2: [0.001, 0.2],
parLP: [0.2, 1],
parPERC: [0, 10],
parUZL: [0, 100],
parTT: [-2.5, 2.5],
parCFMAX: [0.5, 10],
parCFR: [0, 0.1],
parCWH: [0, 0.2]
parBETAET: [0.3, 5],
parC: [0, 1]
}
self.conv_routing_hydro_model_bound = [
[0, 2.9], # routing parameter a
[0, 6.5] # routing parameter b
Expand All @@ -63,13 +64,10 @@ def __init__(self, config=None, device=None):
self.routing = config['phy_model']['routing']
self.nearzero = config['phy_model']['nearzero']
self.nmul = config['nmul']

if 'parBETAET' in config['phy_model']['dy_params']['HBV_1_1p']:
self.parameter_bounds['parBETAET'] = [0.3, 5]

def forward(self, x, parameters, routing_parameters=None, muwts=None,
comprout=False):
"""Forward pass for HBV1.1p"""
"""Forward pass for HBV1.1p."""
# Initialization
if not self.warm_up_states:
# No state warm up - run the full model for warm_up days.
Expand Down Expand Up @@ -98,20 +96,20 @@ def forward(self, x, parameters, routing_parameters=None, muwts=None,
)
else:
# Without warm-up, initialize state variables with zeros.
Ngrid = x.shape[1]
SNOWPACK = torch.zeros([Ngrid, self.nmul],
n_grid = x.shape[1]
SNOWPACK = torch.zeros([n_grid, self.nmul],
dtype=torch.float32,
device=self.device) + 0.001
MELTWATER = torch.zeros([Ngrid, self.nmul],
MELTWATER = torch.zeros([n_grid, self.nmul],
dtype=torch.float32,
device=self.device) + 0.001
SM = torch.zeros([Ngrid, self.nmul],
SM = torch.zeros([n_grid, self.nmul],
dtype=torch.float32,
device=self.device) + 0.001
SUZ = torch.zeros([Ngrid, self.nmul],
SUZ = torch.zeros([n_grid, self.nmul],
dtype=torch.float32,
device=self.device) + 0.001
SLZ = torch.zeros([Ngrid, self.nmul],
SLZ = torch.zeros([n_grid, self.nmul],
dtype=torch.float32,
device=self.device) + 0.001

Expand All @@ -133,10 +131,10 @@ def forward(self, x, parameters, routing_parameters=None, muwts=None,
Tm = T.unsqueeze(2).repeat(1, 1, self.nmul)
PETm = PET.unsqueeze(-1).repeat(1, 1, self.nmul)

Nstep, Ngrid = P.size()
n_steps, n_grid = P.size()

# Apply correction factor to precipitation
# P = parPCORR.repeat(Nstep, 1) * P
# P = parPCORR.repeat(n_steps, 1) * P

# Initialize time series of model variables in shape [time, basins, nmul].
Qsimmu = torch.zeros(Pm.size(), dtype=torch.float32, device=self.device) + 0.001
Expand Down Expand Up @@ -165,14 +163,14 @@ def forward(self, x, parameters, routing_parameters=None, muwts=None,
# as static in some basins.)
if len(self.dy_params) > 0:
params_dict_raw_dy = dict()
pmat = torch.ones([Ngrid, 1]) * self.dy_drop
pmat = torch.ones([n_grid, 1]) * self.dy_drop
for i, key in enumerate(self.dy_params):
drmask = torch.bernoulli(pmat).detach_().to(self.device)
dynPar = params_dict_raw[key]
staPar = params_dict_raw[key][self.static_idx, :, :].unsqueeze(0).repeat([dynPar.shape[0], 1, 1])
params_dict_raw_dy[key] = dynPar * (1 - drmask) + staPar * drmask

for t in range(Nstep):
for t in range(n_steps):
# Get dynamic parameter values per timestep.
for key in self.dy_params:
params_dict[key] = params_dict_raw_dy[key][self.warm_up + t, :, :]
Expand Down Expand Up @@ -267,7 +265,7 @@ def forward(self, x, parameters, routing_parameters=None, muwts=None,
# Routing for all components or just the average.
if comprout:
# All components; reshape to [time, gages * num models]
Qsim = Qsimmu.view(Nstep, Ngrid * self.nmul)
Qsim = Qsimmu.view(n_steps, n_grid * self.nmul)
else:
# Average, then do routing.
Qsim = Qsimavg
Expand All @@ -281,8 +279,8 @@ def forward(self, x, parameters, routing_parameters=None, muwts=None,
param=routing_parameters[:, 1],
bounds=self.conv_routing_hydro_model_bound[1]
)
rout_a = temp_a.repeat(Nstep, 1).unsqueeze(-1)
rout_b = temp_b.repeat(Nstep, 1).unsqueeze(-1)
rout_a = temp_a.repeat(n_steps, 1).unsqueeze(-1)
rout_b = temp_b.repeat(n_steps, 1).unsqueeze(-1)

UH = UH_gamma(rout_a, rout_b, lenF=15) # lenF: folter
rf = torch.unsqueeze(Qsim, -1).permute([1, 2, 0]) # [gages,vars,time]
Expand All @@ -299,7 +297,7 @@ def forward(self, x, parameters, routing_parameters=None, muwts=None,

if comprout:
# Qs is now shape [time, [gages*num models], vars]
Qstemp = Qsrout.view(Nstep, Ngrid, self.nmul)
Qstemp = Qsrout.view(n_steps, n_grid, self.nmul)
if muwts is None:
Qs = Qstemp.mean(-1, keepdim=True)
else:
Expand Down
Loading

0 comments on commit c14de9b

Please sign in to comment.