Source code for autogllight.nas.space.graces

from . import BaseSpace
from .graces_space import *


[docs]class GracesSpace(BaseSpace): def __init__(self, input_dim, output_dim, num_nodes, mol, virtual, criterion, args): super().__init__() self.input_dim = input_dim self.output_dim = output_dim self.num_nodes = num_nodes self.mol = mol self.virtual = virtual self.criterion = criterion self.args = args self.use_forward = True
[docs] def build_graph(self): self.supernet0 = GEncoder( criterion=self.criterion, in_dim=self.input_dim, out_dim=self.output_dim, hidden_size=self.args.graph_dim, num_layers=2, dropout=0.5, epsilon=self.args.epsilon, args=self.args, with_conv_linear=self.args.with_conv_linear, num_nodes=self.num_nodes, mol=self.mol, virtual=self.virtual, ) self.supernet = Network( criterion=self.criterion, in_dim=self.input_dim, out_dim=self.output_dim, hidden_size=self.args.hidden_size, num_layers=self.args.num_layers, dropout=self.args.dropout, epsilon=self.args.epsilon, args=self.args, with_conv_linear=self.args.with_conv_linear, num_nodes=self.num_nodes, mol=self.mol, virtual=self.virtual, ) num_na_ops = len(NA_PRIMITIVES) num_pool_ops = len(POOL_PRIMITIVES) self.ag = AG(args=self.args, num_op=num_na_ops, num_pool=num_pool_ops) self.explore_num = 0
[docs] def forward(self, data): if not self.use_forward: return self.prediction # mhssl pred0, graph_emb, sslout = self.supernet0(data, mode="mixed") # ag graph_alpha, cosloss = self.ag(graph_emb) # final supernet pred, _ = self.supernet(data, mode="mads", graph_alpha=graph_alpha) self.current_pred = pred return pred0, pred, cosloss, sslout
def keep_prediction(self): self.prediction = self.current_pred
[docs] def parse_model(self, selection): self.use_forward = False return self.wrap()