Module imodelsx.kan.kan_modules
Code for KANLinearModule and KANModule, taken from efficient-kan (https://github.com/Blealtan/efficient-kan/tree/master). Original implementation: pykan (https://github.com/KindXiaoming/pykan). Original paper: "KAN: Kolmogorov-Arnold Networks" (https://arxiv.org/abs/2404.19756).
Expand source code
'''Code for KanLinearModule and KANModule taken from https://github.com/Blealtan/efficient-kan/tree/master
Original implementation at https://github.com/KindXiaoming/pykan
Original paper: "KAN: Kolmogorov-Arnold Networks" https://arxiv.org/abs/2404.19756
'''
import torch
import torch.nn.functional as F
import math
from typing import List
class KANLinearModule(torch.nn.Module):
    """One KAN layer: a base linear branch plus a learnable B-spline branch.

    The forward pass returns
    ``F.linear(base_activation(x), base_weight)`` plus a linear map applied
    to the B-spline bases of ``x`` (see :meth:`forward`).

    Args:
        in_features: Number of input features.
        out_features: Number of output features.
        grid_size: Number of grid intervals inside ``grid_range``.
        spline_order: Order of the B-splines; the knot grid is extended by
            this many knots on each side of ``grid_range``.
        scale_noise: Scale of the uniform noise that the initial spline
            coefficients are fitted to.
        scale_base: Multiplier on the ``a`` argument of the Kaiming-uniform
            initialisation of ``base_weight``.
        scale_spline: Multiplier on the initial spline coefficients when
            ``enable_standalone_scale_spline`` is false; otherwise a
            multiplier on the ``a`` argument of the Kaiming-uniform
            initialisation of ``spline_scaler``.
        enable_standalone_scale_spline: If true, a learnable
            ``spline_scaler`` of shape ``(out_features, in_features)``
            multiplies the spline coefficients.
        base_activation: Zero-argument callable returning the activation
            module used by the base branch.
        grid_eps: Blend weight used by :meth:`update_grid`: the new grid is
            ``grid_eps * uniform + (1 - grid_eps) * adaptive``.
        grid_range: Two-element sequence ``(low, high)`` bounding the
            initial grid. Only indexed, never mutated.
    """

    def __init__(
        self,
        in_features,
        out_features,
        grid_size=5,
        spline_order=3,
        scale_noise=0.1,
        scale_base=1.0,
        scale_spline=1.0,
        enable_standalone_scale_spline=True,
        base_activation=torch.nn.SiLU,
        grid_eps=0.02,
        grid_range=(-1, 1),
    ):
        super(KANLinearModule, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.grid_size = grid_size
        self.spline_order = spline_order

        # Uniform knot vector, extended by `spline_order` knots on each side
        # and shared (copied) across all input features.
        h = (grid_range[1] - grid_range[0]) / grid_size
        grid = (
            (
                torch.arange(-spline_order, grid_size + spline_order + 1) * h
                + grid_range[0]
            )
            .expand(in_features, -1)
            .contiguous()
        )
        self.register_buffer("grid", grid)

        # Parameters are allocated uninitialised here and filled in by
        # reset_parameters() below.
        self.base_weight = torch.nn.Parameter(
            torch.empty(out_features, in_features))
        self.spline_weight = torch.nn.Parameter(
            torch.empty(out_features, in_features, grid_size + spline_order)
        )
        if enable_standalone_scale_spline:
            self.spline_scaler = torch.nn.Parameter(
                torch.empty(out_features, in_features)
            )

        self.scale_noise = scale_noise
        self.scale_base = scale_base
        self.scale_spline = scale_spline
        self.enable_standalone_scale_spline = enable_standalone_scale_spline
        self.base_activation = base_activation()
        self.grid_eps = grid_eps

        self.reset_parameters()

    def reset_parameters(self):
        """(Re-)initialise ``base_weight``, ``spline_weight`` and, if present,
        ``spline_scaler``.

        The spline coefficients are obtained by fitting small uniform noise
        sampled at the in-range grid points.
        """
        torch.nn.init.kaiming_uniform_(
            self.base_weight, a=math.sqrt(5) * self.scale_base)
        with torch.no_grad():
            # Create the noise on the grid's device/dtype so that calling
            # this method after `.to(...)` / `.double()` does not hit a
            # device or dtype mismatch inside curve2coeff.
            noise = (
                (
                    torch.rand(
                        self.grid_size + 1,
                        self.in_features,
                        self.out_features,
                        dtype=self.grid.dtype,
                        device=self.grid.device,
                    )
                    - 1 / 2
                )
                * self.scale_noise
                / self.grid_size
            )
            initial_scale = (
                1.0 if self.enable_standalone_scale_spline
                else self.scale_spline
            )
            self.spline_weight.data.copy_(
                initial_scale
                * self.curve2coeff(
                    # Only the grid_size + 1 knots inside the original range.
                    self.grid.T[self.spline_order: -self.spline_order],
                    noise,
                )
            )
            if self.enable_standalone_scale_spline:
                torch.nn.init.kaiming_uniform_(
                    self.spline_scaler, a=math.sqrt(5) * self.scale_spline)

    def b_splines(self, x: torch.Tensor):
        """
        Compute the B-spline bases for the given input tensor.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).

        Returns:
            torch.Tensor: B-spline bases tensor of shape
                (batch_size, in_features, grid_size + spline_order).
        """
        assert x.dim() == 2 and x.size(1) == self.in_features

        # (in_features, grid_size + 2 * spline_order + 1)
        grid: torch.Tensor = self.grid
        x = x.unsqueeze(-1)
        # Order-0 bases: indicator of the half-open knot interval holding x.
        bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)
        # Raise the order one step at a time; each step drops one basis.
        for k in range(1, self.spline_order + 1):
            left = (
                (x - grid[:, : -(k + 1)])
                / (grid[:, k:-1] - grid[:, : -(k + 1)])
                * bases[:, :, :-1]
            )
            right = (
                (grid[:, k + 1:] - x)
                / (grid[:, k + 1:] - grid[:, 1:(-k)])
                * bases[:, :, 1:]
            )
            bases = left + right

        assert bases.size() == (
            x.size(0),
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return bases.contiguous()

    def curve2coeff(self, x: torch.Tensor, y: torch.Tensor):
        """
        Compute the coefficients of the curve that interpolates the given points.

        The coefficients are the least-squares solution of
        ``b_splines(x) @ coeff = y``, solved independently per input feature.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).
            y (torch.Tensor): Output tensor of shape
                (batch_size, in_features, out_features).

        Returns:
            torch.Tensor: Coefficients tensor of shape
                (out_features, in_features, grid_size + spline_order).
        """
        assert x.dim() == 2 and x.size(1) == self.in_features
        assert y.size() == (x.size(0), self.in_features, self.out_features)

        # (in_features, batch_size, grid_size + spline_order)
        A = self.b_splines(x).transpose(0, 1)
        # (in_features, batch_size, out_features)
        B = y.transpose(0, 1)
        # (in_features, grid_size + spline_order, out_features)
        solution = torch.linalg.lstsq(A, B).solution
        # (out_features, in_features, grid_size + spline_order)
        result = solution.permute(2, 0, 1)

        assert result.size() == (
            self.out_features,
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return result.contiguous()

    @property
    def scaled_spline_weight(self):
        """Spline coefficients, multiplied by ``spline_scaler`` when the
        standalone scaler is enabled (otherwise returned unscaled)."""
        if self.enable_standalone_scale_spline:
            return self.spline_weight * self.spline_scaler.unsqueeze(-1)
        return self.spline_weight * 1.0

    def forward(self, x: torch.Tensor):
        """Apply the layer.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).

        Returns:
            torch.Tensor: Sum of the base branch and the spline branch,
                of shape (batch_size, out_features).
        """
        assert x.dim() == 2 and x.size(1) == self.in_features

        base_output = F.linear(self.base_activation(x), self.base_weight)
        # Flatten (in_features, coeff) so the spline branch is one matmul.
        spline_output = F.linear(
            self.b_splines(x).view(x.size(0), -1),
            self.scaled_spline_weight.view(self.out_features, -1),
        )
        return base_output + spline_output

    @torch.no_grad()
    def update_grid(self, x: torch.Tensor, margin=0.01):
        """Re-fit the knot grid to the distribution of ``x`` and re-fit the
        spline coefficients on the new grid.

        The new grid blends a quantile-based (adaptive) grid with a uniform
        grid spanning ``[min(x) - margin, max(x) + margin]``, weighted by
        ``grid_eps``. ``grid`` and ``spline_weight`` are updated in place.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).
            margin: Padding added on each side of the observed range when
                building the uniform grid.
        """
        assert x.dim() == 2 and x.size(1) == self.in_features
        batch = x.size(0)

        # Per-input-feature spline outputs on the *current* grid; these are
        # the targets the coefficients are re-fitted to afterwards.
        splines = self.b_splines(x).permute(1, 0, 2)  # (in, batch, coeff)
        orig_coeff = self.scaled_spline_weight.permute(
            1, 2, 0)  # (in, coeff, out)
        unreduced_spline_output = torch.bmm(
            splines, orig_coeff).permute(1, 0, 2)  # (batch, in, out)

        # sort each channel individually to collect data distribution
        x_sorted = torch.sort(x, dim=0)[0]
        grid_adaptive = x_sorted[
            torch.linspace(
                0, batch - 1, self.grid_size + 1,
                dtype=torch.int64, device=x.device
            )
        ]

        uniform_step = (
            x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_size
        grid_uniform = (
            torch.arange(
                self.grid_size + 1, dtype=torch.float32, device=x.device
            ).unsqueeze(1)
            * uniform_step
            + x_sorted[0]
            - margin
        )

        grid = (
            self.grid_eps * grid_uniform
            + (1 - self.grid_eps) * grid_adaptive
        )
        # Extend by `spline_order` uniformly spaced knots on each side.
        lower_ext = grid[:1] - uniform_step * torch.arange(
            self.spline_order, 0, -1, device=x.device).unsqueeze(1)
        upper_ext = grid[-1:] + uniform_step * torch.arange(
            1, self.spline_order + 1, device=x.device).unsqueeze(1)
        grid = torch.cat([lower_ext, grid, upper_ext], dim=0)

        self.grid.copy_(grid.T)
        self.spline_weight.data.copy_(
            self.curve2coeff(x, unreduced_spline_output))

    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
        """
        Compute the regularization loss.

        This is a rough stand-in for the original L1 regularization as stated
        in the paper, since the original one requires computing absolutes and
        entropy from the expanded (batch, in_features, out_features)
        intermediate tensor, which is hidden behind the F.linear function if
        we want a memory-efficient implementation.

        The L1 regularization is now computed as mean absolute value of the
        spline weights. The authors' implementation also includes this term in
        addition to the sample-based regularization.

        Args:
            regularize_activation: Weight of the L1 (activation) term.
            regularize_entropy: Weight of the entropy term.

        Returns:
            torch.Tensor: Scalar loss.
        """
        l1_fake = self.spline_weight.abs().mean(-1)
        regularization_loss_activation = l1_fake.sum()

        finfo = torch.finfo(l1_fake.dtype)
        # Guard the normalisation against an all-zero weight tensor (0 / 0).
        p = l1_fake / torch.clamp(
            regularization_loss_activation, min=finfo.eps)
        # Guard log(0): an exactly-zero entry contributes 0 to the entropy
        # instead of 0 * -inf = NaN. Entries >= finfo.tiny are unaffected.
        log_p = torch.clamp(p, min=finfo.tiny).log()
        regularization_loss_entropy = -torch.sum(p * log_p)

        return (
            regularize_activation * regularization_loss_activation
            + regularize_entropy * regularization_loss_entropy
        )
class KANModule(torch.nn.Module):
    """A stack of :class:`KANLinearModule` layers applied in sequence.

    Args:
        layers_hidden: Sequence of layer widths; one ``KANLinearModule`` is
            created for each consecutive pair ``(in_features, out_features)``.
        grid_size, spline_order, scale_noise, scale_base, scale_spline,
        base_activation, grid_eps, grid_range: Forwarded unchanged to every
            ``KANLinearModule``.
    """

    def __init__(
        self,
        layers_hidden,
        grid_size=5,
        spline_order=3,
        scale_noise=0.1,
        scale_base=1.0,
        scale_spline=1.0,
        base_activation=torch.nn.SiLU,
        grid_eps=0.02,
        grid_range=(-1, 1),
    ):
        super(KANModule, self).__init__()
        self.grid_size = grid_size
        self.spline_order = spline_order

        # One layer per consecutive pair of widths.
        self.layers = torch.nn.ModuleList(
            KANLinearModule(
                in_features,
                out_features,
                grid_size=grid_size,
                spline_order=spline_order,
                scale_noise=scale_noise,
                scale_base=scale_base,
                scale_spline=scale_spline,
                base_activation=base_activation,
                grid_eps=grid_eps,
                grid_range=grid_range,
            )
            for in_features, out_features in zip(layers_hidden, layers_hidden[1:])
        )

    def forward(self, x: torch.Tensor, update_grid=False):
        """Run ``x`` through every layer in order.

        Args:
            x (torch.Tensor): Input tensor; each layer asserts it is 2-D with
                a matching feature dimension.
            update_grid: If true, each layer's ``update_grid`` is called on
                its own input before the layer is applied.

        Returns:
            torch.Tensor: Output of the last layer.
        """
        for layer in self.layers:
            if update_grid:
                layer.update_grid(x)
            x = layer(x)
        return x

    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
        """Return the sum of the per-layer regularization losses.

        Args:
            regularize_activation: Weight of each layer's L1 term.
            regularize_entropy: Weight of each layer's entropy term.
        """
        return sum(
            layer.regularization_loss(
                regularize_activation, regularize_entropy)
            for layer in self.layers
        )
class KANGAMModule(torch.nn.Module):
    '''Learn a KAN model on each individual input feature.

    Each feature gets its own ``KANModule`` mapping one input to one output;
    the per-feature outputs are combined by a final ``torch.nn.Linear``.

    Args:
        num_features: Number of input features (one ``KANModule`` each).
        layers_hidden: Hidden layer widths of every per-feature model. Any
            sequence of ints is accepted; widths of 1 are added at both ends.
        n_classes: Output size of the final linear layer.
        **kwargs: Forwarded unchanged to every ``KANModule``.
    '''

    def __init__(self, num_features, layers_hidden: List[int], n_classes, **kwargs):
        super(KANGAMModule, self).__init__()
        # list(...) lets callers pass a tuple or other sequence, which
        # `[1] + layers_hidden` would reject with a TypeError.
        widths = [1] + list(layers_hidden) + [1]
        self.models = torch.nn.ModuleList([
            KANModule(layers_hidden=widths, **kwargs)
            for _ in range(num_features)
        ])
        self.linear = torch.nn.Linear(num_features, n_classes)

    def forward(self, x: torch.Tensor, update_grid=False):
        """Apply each per-feature model to its column and combine linearly.

        Args:
            x (torch.Tensor): Input indexed as ``x[:, i:i + 1]`` for feature
                ``i``, so it needs one column per per-feature model.
            update_grid: Forwarded to every per-feature ``KANModule``.

        Returns:
            torch.Tensor: Output of the final linear layer.
        """
        # Each model returns a (batch, 1) column; concatenating along dim 1
        # yields (batch, num_features).
        features = torch.cat(
            [model(x[:, i:i + 1], update_grid)
             for i, model in enumerate(self.models)],
            dim=1)
        return self.linear(features)

    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0, regularize_ridge=1.0):
        """Return the summed layer regularization over all per-feature models,
        plus a ridge penalty on the final linear layer.

        Args:
            regularize_activation: Weight of each layer's L1 term.
            regularize_entropy: Weight of each layer's entropy term.
            regularize_ridge: Weight of the 2-norm of ``self.linear.weight``.
        """
        kan_loss = sum(
            layer.regularization_loss(
                regularize_activation, regularize_entropy)
            for model in self.models
            for layer in model.layers
        )
        return kan_loss + regularize_ridge * self.linear.weight.norm(p=2)
Classes
class KANGAMModule (num_features, layers_hidden: List[int], n_classes, **kwargs)
-
Learn a KAN model on each individual input feature
Initialize internal Module state, shared by both nn.Module and ScriptModule.
Expand source code
class KANGAMModule(torch.nn.Module): '''Learn a KAN model on each individual input feature ''' def __init__(self, num_features, layers_hidden: List[int], n_classes, **kwargs): super(KANGAMModule, self).__init__() self.models = torch.nn.ModuleList([ KANModule( layers_hidden=[1] + layers_hidden + [1], **kwargs) for _ in range(num_features) ]) self.linear = torch.nn.Linear(num_features, n_classes) def forward(self, x: torch.Tensor, update_grid=False): features = torch.stack( [model(x[:, i:i + 1], update_grid) for i, model in enumerate(self.models)], dim=1) features = features.view(x.size(0), -1) return self.linear(features) def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0, regularize_ridge=1.0): return sum( layer.regularization_loss( regularize_activation, regularize_entropy) for model in self.models for layer in model.layers ) + regularize_ridge * self.linear.weight.norm(p=2)
Ancestors
- torch.nn.modules.module.Module
Methods
def forward(self, x: torch.Tensor, update_grid=False) ‑> Callable[..., Any]
-
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
Expand source code
def forward(self, x: torch.Tensor, update_grid=False): features = torch.stack( [model(x[:, i:i + 1], update_grid) for i, model in enumerate(self.models)], dim=1) features = features.view(x.size(0), -1) return self.linear(features)
def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0, regularize_ridge=1.0)
-
Expand source code
def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0, regularize_ridge=1.0): return sum( layer.regularization_loss( regularize_activation, regularize_entropy) for model in self.models for layer in model.layers ) + regularize_ridge * self.linear.weight.norm(p=2)
class KANLinearModule (in_features, out_features, grid_size=5, spline_order=3, scale_noise=0.1, scale_base=1.0, scale_spline=1.0, enable_standalone_scale_spline=True, base_activation=torch.nn.modules.activation.SiLU, grid_eps=0.02, grid_range=[-1, 1])
-
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing to nest them in a tree structure. You can assign the submodules as regular attributes::
import torch.nn as nn import torch.nn.functional as F class Model(nn.Module): def __init__(self): super().__init__() self.conv1 = nn.Conv2d(1, 20, 5) self.conv2 = nn.Conv2d(20, 20, 5) def forward(self, x): x = F.relu(self.conv1(x)) return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will have their parameters converted too when you call to(), etc.
Note
As per the example above, an __init__() call to the parent class must be made before assignment on the child.
:ivar training: Boolean representing whether this module is in training or evaluation mode.
:vartype training: bool
Initialize internal Module state, shared by both nn.Module and ScriptModule.
Expand source code
class KANLinearModule(torch.nn.Module): def __init__( self, in_features, out_features, grid_size=5, spline_order=3, scale_noise=0.1, scale_base=1.0, scale_spline=1.0, enable_standalone_scale_spline=True, base_activation=torch.nn.SiLU, grid_eps=0.02, grid_range=[-1, 1], ): super(KANLinearModule, self).__init__() self.in_features = in_features self.out_features = out_features self.grid_size = grid_size self.spline_order = spline_order h = (grid_range[1] - grid_range[0]) / grid_size grid = ( ( torch.arange(-spline_order, grid_size + spline_order + 1) * h + grid_range[0] ) .expand(in_features, -1) .contiguous() ) self.register_buffer("grid", grid) self.base_weight = torch.nn.Parameter( torch.Tensor(out_features, in_features)) self.spline_weight = torch.nn.Parameter( torch.Tensor(out_features, in_features, grid_size + spline_order) ) if enable_standalone_scale_spline: self.spline_scaler = torch.nn.Parameter( torch.Tensor(out_features, in_features) ) self.scale_noise = scale_noise self.scale_base = scale_base self.scale_spline = scale_spline self.enable_standalone_scale_spline = enable_standalone_scale_spline self.base_activation = base_activation() self.grid_eps = grid_eps self.reset_parameters() def reset_parameters(self): torch.nn.init.kaiming_uniform_( self.base_weight, a=math.sqrt(5) * self.scale_base) with torch.no_grad(): noise = ( ( torch.rand(self.grid_size + 1, self.in_features, self.out_features) - 1 / 2 ) * self.scale_noise / self.grid_size ) self.spline_weight.data.copy_( (self.scale_spline if not self.enable_standalone_scale_spline else 1.0) * self.curve2coeff( self.grid.T[self.spline_order: -self.spline_order], noise, ) ) if self.enable_standalone_scale_spline: # torch.nn.init.constant_(self.spline_scaler, self.scale_spline) torch.nn.init.kaiming_uniform_( self.spline_scaler, a=math.sqrt(5) * self.scale_spline) def b_splines(self, x: torch.Tensor): """ Compute the B-spline bases for the given input tensor. 
Args: x (torch.Tensor): Input tensor of shape (batch_size, in_features). Returns: torch.Tensor: B-spline bases tensor of shape (batch_size, in_features, grid_size + spline_order). """ assert x.dim() == 2 and x.size(1) == self.in_features grid: torch.Tensor = ( self.grid ) # (in_features, grid_size + 2 * spline_order + 1) x = x.unsqueeze(-1) bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype) for k in range(1, self.spline_order + 1): bases = ( (x - grid[:, : -(k + 1)]) / (grid[:, k:-1] - grid[:, : -(k + 1)]) * bases[:, :, :-1] ) + ( (grid[:, k + 1:] - x) / (grid[:, k + 1:] - grid[:, 1:(-k)]) * bases[:, :, 1:] ) assert bases.size() == ( x.size(0), self.in_features, self.grid_size + self.spline_order, ) return bases.contiguous() def curve2coeff(self, x: torch.Tensor, y: torch.Tensor): """ Compute the coefficients of the curve that interpolates the given points. Args: x (torch.Tensor): Input tensor of shape (batch_size, in_features). y (torch.Tensor): Output tensor of shape (batch_size, in_features, out_features). Returns: torch.Tensor: Coefficients tensor of shape (out_features, in_features, grid_size + spline_order). 
""" assert x.dim() == 2 and x.size(1) == self.in_features assert y.size() == (x.size(0), self.in_features, self.out_features) A = self.b_splines(x).transpose( 0, 1 ) # (in_features, batch_size, grid_size + spline_order) B = y.transpose(0, 1) # (in_features, batch_size, out_features) solution = torch.linalg.lstsq( A, B ).solution # (in_features, grid_size + spline_order, out_features) result = solution.permute( 2, 0, 1 ) # (out_features, in_features, grid_size + spline_order) assert result.size() == ( self.out_features, self.in_features, self.grid_size + self.spline_order, ) return result.contiguous() @property def scaled_spline_weight(self): return self.spline_weight * ( self.spline_scaler.unsqueeze(-1) if self.enable_standalone_scale_spline else 1.0 ) def forward(self, x: torch.Tensor): assert x.dim() == 2 and x.size(1) == self.in_features base_output = F.linear(self.base_activation(x), self.base_weight) spline_output = F.linear( self.b_splines(x).view(x.size(0), -1), self.scaled_spline_weight.view(self.out_features, -1), ) return base_output + spline_output @torch.no_grad() def update_grid(self, x: torch.Tensor, margin=0.01): assert x.dim() == 2 and x.size(1) == self.in_features batch = x.size(0) splines = self.b_splines(x) # (batch, in, coeff) splines = splines.permute(1, 0, 2) # (in, batch, coeff) orig_coeff = self.scaled_spline_weight # (out, in, coeff) orig_coeff = orig_coeff.permute(1, 2, 0) # (in, coeff, out) unreduced_spline_output = torch.bmm( splines, orig_coeff) # (in, batch, out) unreduced_spline_output = unreduced_spline_output.permute( 1, 0, 2 ) # (batch, in, out) # sort each channel individually to collect data distribution x_sorted = torch.sort(x, dim=0)[0] grid_adaptive = x_sorted[ torch.linspace( 0, batch - 1, self.grid_size + 1, dtype=torch.int64, device=x.device ) ] uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_size grid_uniform = ( torch.arange( self.grid_size + 1, dtype=torch.float32, device=x.device ).unsqueeze(1) * 
uniform_step + x_sorted[0] - margin ) grid = self.grid_eps * grid_uniform + \ (1 - self.grid_eps) * grid_adaptive grid = torch.concatenate( [ grid[:1] - uniform_step * torch.arange(self.spline_order, 0, -1, device=x.device).unsqueeze(1), grid, grid[-1:] + uniform_step * torch.arange(1, self.spline_order + 1, device=x.device).unsqueeze(1), ], dim=0, ) self.grid.copy_(grid.T) self.spline_weight.data.copy_( self.curve2coeff(x, unreduced_spline_output)) def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0): """ Compute the regularization loss. This is a dumb simulation of the original L1 regularization as stated in the paper, since the original one requires computing absolutes and entropy from the expanded (batch, in_features, out_features) intermediate tensor, which is hidden behind the F.linear function if we want an memory efficient implementation. The L1 regularization is now computed as mean absolute value of the spline weights. The authors implementation also includes this term in addition to the sample-based regularization. """ l1_fake = self.spline_weight.abs().mean(-1) regularization_loss_activation = l1_fake.sum() p = l1_fake / regularization_loss_activation regularization_loss_entropy = -torch.sum(p * p.log()) return ( regularize_activation * regularization_loss_activation + regularize_entropy * regularization_loss_entropy )
Ancestors
- torch.nn.modules.module.Module
Instance variables
var scaled_spline_weight
-
Expand source code
@property def scaled_spline_weight(self): return self.spline_weight * ( self.spline_scaler.unsqueeze(-1) if self.enable_standalone_scale_spline else 1.0 )
Methods
def b_splines(self, x: torch.Tensor)
-
Compute the B-spline bases for the given input tensor.
Args
x
:torch.Tensor
- Input tensor of shape (batch_size, in_features).
Returns
torch.Tensor
- B-spline bases tensor of shape (batch_size, in_features, grid_size + spline_order).
Expand source code
def b_splines(self, x: torch.Tensor): """ Compute the B-spline bases for the given input tensor. Args: x (torch.Tensor): Input tensor of shape (batch_size, in_features). Returns: torch.Tensor: B-spline bases tensor of shape (batch_size, in_features, grid_size + spline_order). """ assert x.dim() == 2 and x.size(1) == self.in_features grid: torch.Tensor = ( self.grid ) # (in_features, grid_size + 2 * spline_order + 1) x = x.unsqueeze(-1) bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype) for k in range(1, self.spline_order + 1): bases = ( (x - grid[:, : -(k + 1)]) / (grid[:, k:-1] - grid[:, : -(k + 1)]) * bases[:, :, :-1] ) + ( (grid[:, k + 1:] - x) / (grid[:, k + 1:] - grid[:, 1:(-k)]) * bases[:, :, 1:] ) assert bases.size() == ( x.size(0), self.in_features, self.grid_size + self.spline_order, ) return bases.contiguous()
def curve2coeff(self, x: torch.Tensor, y: torch.Tensor)
-
Compute the coefficients of the curve that interpolates the given points.
Args
x
:torch.Tensor
- Input tensor of shape (batch_size, in_features).
y
:torch.Tensor
- Output tensor of shape (batch_size, in_features, out_features).
Returns
torch.Tensor
- Coefficients tensor of shape (out_features, in_features, grid_size + spline_order).
Expand source code
def curve2coeff(self, x: torch.Tensor, y: torch.Tensor): """ Compute the coefficients of the curve that interpolates the given points. Args: x (torch.Tensor): Input tensor of shape (batch_size, in_features). y (torch.Tensor): Output tensor of shape (batch_size, in_features, out_features). Returns: torch.Tensor: Coefficients tensor of shape (out_features, in_features, grid_size + spline_order). """ assert x.dim() == 2 and x.size(1) == self.in_features assert y.size() == (x.size(0), self.in_features, self.out_features) A = self.b_splines(x).transpose( 0, 1 ) # (in_features, batch_size, grid_size + spline_order) B = y.transpose(0, 1) # (in_features, batch_size, out_features) solution = torch.linalg.lstsq( A, B ).solution # (in_features, grid_size + spline_order, out_features) result = solution.permute( 2, 0, 1 ) # (out_features, in_features, grid_size + spline_order) assert result.size() == ( self.out_features, self.in_features, self.grid_size + self.spline_order, ) return result.contiguous()
def forward(self, x: torch.Tensor) ‑> Callable[..., Any]
-
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
Expand source code
def forward(self, x: torch.Tensor): assert x.dim() == 2 and x.size(1) == self.in_features base_output = F.linear(self.base_activation(x), self.base_weight) spline_output = F.linear( self.b_splines(x).view(x.size(0), -1), self.scaled_spline_weight.view(self.out_features, -1), ) return base_output + spline_output
def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0)
-
Compute the regularization loss.
This is a crude approximation of the original L1 regularization as stated in the paper, since the original one requires computing absolute values and entropy from the expanded (batch, in_features, out_features) intermediate tensor, which is hidden behind the F.linear function if we want a memory-efficient implementation.
The L1 regularization is now computed as the mean absolute value of the spline weights. The authors' implementation also includes this term in addition to the sample-based regularization.
Expand source code
def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0): """ Compute the regularization loss. This is a dumb simulation of the original L1 regularization as stated in the paper, since the original one requires computing absolutes and entropy from the expanded (batch, in_features, out_features) intermediate tensor, which is hidden behind the F.linear function if we want an memory efficient implementation. The L1 regularization is now computed as mean absolute value of the spline weights. The authors implementation also includes this term in addition to the sample-based regularization. """ l1_fake = self.spline_weight.abs().mean(-1) regularization_loss_activation = l1_fake.sum() p = l1_fake / regularization_loss_activation regularization_loss_entropy = -torch.sum(p * p.log()) return ( regularize_activation * regularization_loss_activation + regularize_entropy * regularization_loss_entropy )
def reset_parameters(self)
-
Expand source code
def reset_parameters(self): torch.nn.init.kaiming_uniform_( self.base_weight, a=math.sqrt(5) * self.scale_base) with torch.no_grad(): noise = ( ( torch.rand(self.grid_size + 1, self.in_features, self.out_features) - 1 / 2 ) * self.scale_noise / self.grid_size ) self.spline_weight.data.copy_( (self.scale_spline if not self.enable_standalone_scale_spline else 1.0) * self.curve2coeff( self.grid.T[self.spline_order: -self.spline_order], noise, ) ) if self.enable_standalone_scale_spline: # torch.nn.init.constant_(self.spline_scaler, self.scale_spline) torch.nn.init.kaiming_uniform_( self.spline_scaler, a=math.sqrt(5) * self.scale_spline)
def update_grid(self, x: torch.Tensor, margin=0.01)
-
Expand source code
@torch.no_grad() def update_grid(self, x: torch.Tensor, margin=0.01): assert x.dim() == 2 and x.size(1) == self.in_features batch = x.size(0) splines = self.b_splines(x) # (batch, in, coeff) splines = splines.permute(1, 0, 2) # (in, batch, coeff) orig_coeff = self.scaled_spline_weight # (out, in, coeff) orig_coeff = orig_coeff.permute(1, 2, 0) # (in, coeff, out) unreduced_spline_output = torch.bmm( splines, orig_coeff) # (in, batch, out) unreduced_spline_output = unreduced_spline_output.permute( 1, 0, 2 ) # (batch, in, out) # sort each channel individually to collect data distribution x_sorted = torch.sort(x, dim=0)[0] grid_adaptive = x_sorted[ torch.linspace( 0, batch - 1, self.grid_size + 1, dtype=torch.int64, device=x.device ) ] uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_size grid_uniform = ( torch.arange( self.grid_size + 1, dtype=torch.float32, device=x.device ).unsqueeze(1) * uniform_step + x_sorted[0] - margin ) grid = self.grid_eps * grid_uniform + \ (1 - self.grid_eps) * grid_adaptive grid = torch.concatenate( [ grid[:1] - uniform_step * torch.arange(self.spline_order, 0, -1, device=x.device).unsqueeze(1), grid, grid[-1:] + uniform_step * torch.arange(1, self.spline_order + 1, device=x.device).unsqueeze(1), ], dim=0, ) self.grid.copy_(grid.T) self.spline_weight.data.copy_( self.curve2coeff(x, unreduced_spline_output))
class KANModule (layers_hidden, grid_size=5, spline_order=3, scale_noise=0.1, scale_base=1.0, scale_spline=1.0, base_activation=torch.nn.modules.activation.SiLU, grid_eps=0.02, grid_range=[-1, 1])
-
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing to nest them in a tree structure. You can assign the submodules as regular attributes::
import torch.nn as nn import torch.nn.functional as F class Model(nn.Module): def __init__(self): super().__init__() self.conv1 = nn.Conv2d(1, 20, 5) self.conv2 = nn.Conv2d(20, 20, 5) def forward(self, x): x = F.relu(self.conv1(x)) return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will have their parameters converted too when you call to(), etc.
Note
As per the example above, an __init__() call to the parent class must be made before assignment on the child.
:ivar training: Boolean representing whether this module is in training or evaluation mode.
:vartype training: bool
Initialize internal Module state, shared by both nn.Module and ScriptModule.
Expand source code
class KANModule(torch.nn.Module): def __init__( self, layers_hidden, grid_size=5, spline_order=3, scale_noise=0.1, scale_base=1.0, scale_spline=1.0, base_activation=torch.nn.SiLU, grid_eps=0.02, grid_range=[-1, 1], ): super(KANModule, self).__init__() self.grid_size = grid_size self.spline_order = spline_order self.layers = torch.nn.ModuleList() for in_features, out_features in zip(layers_hidden, layers_hidden[1:]): self.layers.append( KANLinearModule( in_features, out_features, grid_size=grid_size, spline_order=spline_order, scale_noise=scale_noise, scale_base=scale_base, scale_spline=scale_spline, base_activation=base_activation, grid_eps=grid_eps, grid_range=grid_range, ) ) def forward(self, x: torch.Tensor, update_grid=False): for layer in self.layers: if update_grid: layer.update_grid(x) x = layer(x) return x def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0): return sum( layer.regularization_loss( regularize_activation, regularize_entropy) for layer in self.layers )
Ancestors
- torch.nn.modules.module.Module
Methods
def forward(self, x: torch.Tensor, update_grid=False) ‑> Callable[..., Any]
-
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
Expand source code
def forward(self, x: torch.Tensor, update_grid=False): for layer in self.layers: if update_grid: layer.update_grid(x) x = layer(x) return x
def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0)
-
Expand source code
def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0): return sum( layer.regularization_loss( regularize_activation, regularize_entropy) for layer in self.layers )