diff --git a/README.md b/README.md index 57aa454..7c7e583 100644 --- a/README.md +++ b/README.md @@ -23,62 +23,3 @@ AAM integrated into `chython` package and available as reaction object method. S r.reset_mapping() print(format(r, 'm')) >> [C:2]([C:4](=[CH2:5])[CH:6]=[CH2:7])(=[O:3])[OH:1].[CH2:8]=[CH:9][C:10]#[N:11]>>[O:3]=[C:2]([OH:1])[C:4]=1[CH2:5][CH:9]([C:10]#[N:11])[CH2:8][CH2:7][CH:6]=1 - - -Pretrained model ----------------- - -**To load pretrained model use:** - - from chytorch.zoo.rxnmap import Model - model = Model.pretrained() - -**To prepare data-loader use:** - - from chython import SMILESRead - - data = [] - for r in SMILESRead('data.smi'): - r.canonicalize() # fix aromaticity and functional groups - data.append(r.pack()) # store in compressed format - - dl = model.prepare_dataloader(data, batch_size=20) - -**To get embeddings use:** - - for b in dl: - e = model(b) - -Note: embeddings contain: `cls embedding, [unusable molecular embedding, list of atoms embeddings] * n`. -Where n is the number of molecules in reaction equation. - -To extract aggregated embedding, use cls embedding `x = e[:, 0]`. - -To extract atoms-only embeddings, use masking: -* `x = e[b[3] > 1]` - for all atoms -* `x = e[b[3] == 2]` - for reactants only -* `x = e[b[3] == 3]` - for products only - -**To get all-to-all tokens attention matrix:** - - for b in dl: - a = model(b, mapping_task=True) - - -Training new model ------------------- - - from pytorch_lightning import Trainer - from pytorch_lightning.callbacks import ModelCheckpoint - from pytorch_lightning.plugins import DDPPlugin - - callback = ModelCheckpoint(monitor='trn_loss_tot', save_weights_only=True, save_top_k=3, save_last=True, - every_n_train_steps=10000) - trainer = Trainer(gpus=-1, precision=16, max_steps=1000000, callbacks=[callback], - strategy=DDPPlugin(find_unused_parameters=False)) - - model = Model(lr_warmup=1e4, lr_period=5e5, lr_max=1e-4, lr_decrease_coef=.01, masking_rate=.15, **kwargs) - # lr_warmup=1e4, lr_period=5e5, lr_max=1e-4, lr_decrease_coef=.01 - see chytorch.optim.lr_scheduler.WarmUpCosine. - # kwargs - see chytorch.nn.ReactionEncoder. - # masking_rate - probability of token masking. - trainer.fit(model, dl) diff --git a/chytorch/zoo/rxnmap/__init__.py b/chytorch/zoo/rxnmap/__init__.py index 9e93d4b..65a1042 100644 --- a/chytorch/zoo/rxnmap/__init__.py +++ b/chytorch/zoo/rxnmap/__init__.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- # -# Copyright 2021, 2022 Ramil Nugmanov +# Copyright 2021-2023 Ramil Nugmanov # This file is part of chytorch. # # chytorch is free software; you can redistribute it and/or modify @@ -16,86 +16,34 @@ # You should have received a copy of the GNU Lesser General Public License # along with this program; if not, see . # -from functools import partial +from math import inf from pkg_resources import resource_stream -from pytorch_lightning import LightningModule -from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor -from torch import rand -from torch.nn import LazyLinear, Parameter -from torch.nn.functional import cross_entropy -from torch.optim import AdamW, Optimizer -from torch.optim.lr_scheduler import _LRScheduler -from torch.utils.data import DataLoader -from typing import Callable, Iterator -from ...nn import ReactionEncoder -from ...optim.lr_scheduler import WarmUpCosine -from ...utils.data import ReactionDataset, collate_reactions +from torch import load, zeros_like, float as t_float +from chytorch.nn import ReactionEncoder +from chytorch.utils.data import ReactionEncoderDataset, collate_encoded_reactions -class Model(LightningModule): - def __init__(self, *, masking_rate=.15, - lr_scheduler: Callable[[Optimizer], _LRScheduler] = None, - optimizer: Callable[[Iterator[Parameter]], Optimizer] = None, **kwargs): - super().__init__() - self.encoder = ReactionEncoder(**kwargs) - self.mlma = LazyLinear(118) - self.mlmn = LazyLinear(self.encoder.molecule_encoder.centrality_encoder.num_embeddings - 2) - - if lr_scheduler is None: - lr_scheduler = partial(WarmUpCosine, decrease_coef=.01, warmup=int(1e4), period=int(5e5)) - if optimizer is None: - optimizer = partial(AdamW, lr=1e-4) - - self.lr_scheduler = lr_scheduler - self.optimizer = optimizer - self.masking_rate = masking_rate - self.save_hyperparameters(kwargs) - - @classmethod - def pretrained(cls, **kwargs): - model = cls.load_from_checkpoint(resource_stream(__package__, 'weights.pt'), map_location='cpu', **kwargs) - model.eval() - return model - - def prepare_dataloader(self, reactions, **kwargs): - """ - Prepare dataloader for training. - - :param reactions: chython packed reactions list. - """ - ds = ReactionDataset(reactions, distance_cutoff=self.encoder.max_distance, unpack=True) - return DataLoader(ds, collate_fn=collate_reactions, **kwargs) - def forward(self, batch, *, mapping_task=False): - if mapping_task: - return self.encoder(batch, need_embedding=False, need_weights=True) - return self.encoder(batch) - - def training_step(self, batch, batch_idx): - a, n, d, r = batch - m = r > 1 # atoms only - ma = a.masked_fill((rand(a.shape, device=a.device) < self.masking_rate) & m, 2) - mn = n.masked_fill((rand(n.shape, device=n.device) < self.masking_rate) & m, 1) - - x = self.encoder((ma, mn, d, r))[m] # atoms only embedding - atoms = self.mlma(x) - neighbors = self.mlmn(x) - - l1 = cross_entropy(atoms, a[m].long() - 3) - l2 = cross_entropy(neighbors, n[m].long() - 2) - self.log('trn_loss_mlm_a', l1.item(), sync_dist=True) - self.log('trn_loss_mlm_n', l2.item(), sync_dist=True) - self.log('trn_loss_tot', l1.item() + l2.item(), sync_dist=True) - return l1 + l2 - - def configure_callbacks(self): - return [ModelCheckpoint(save_weights_only=True, save_last=True, every_n_train_steps=10000), - LearningRateMonitor(logging_interval='step')] - - def configure_optimizers(self): - o = self.optimizer(self.parameters()) - s = self.lr_scheduler(o) - return [o], [{'scheduler': s, 'interval': 'step', 'name': 'lr_scheduler'}] +class Model(ReactionEncoder): + def __init__(self): + super().__init__() + self.load_state_dict(load(resource_stream(__package__, 'weights.pt'))) + self.eval() + + def forward(self, reaction): + dev = self.role_encoder.weight.device + atoms, neighbors, distances, roles = collate_encoded_reactions([ReactionEncoderDataset([reaction])[0]]).to(dev) + n = atoms.size(1) + d_mask = zeros_like(roles, dtype=t_float).masked_fill_(roles == 0, -inf).view(-1, 1, 1, n) + d_mask = d_mask.expand(-1, self.nhead, n, -1).flatten(end_dim=1) + + x = self.molecule_encoder((atoms, neighbors, distances)) * (roles > 1).unsqueeze_(-1) + x = x + self.role_encoder(roles) + + for lr in self.layers[:-1]: # noqa + x, _ = lr(x, d_mask) + _, a = self.layers[-1](x, d_mask, need_embedding=False, need_weights=True) + return a[0] __all__ = ['Model'] diff --git a/chytorch/zoo/rxnmap/weights.pt b/chytorch/zoo/rxnmap/weights.pt index bf1c6fb..7ea22c5 100644 Binary files a/chytorch/zoo/rxnmap/weights.pt and b/chytorch/zoo/rxnmap/weights.pt differ diff --git a/setup.py b/setup.py index 42e7419..773878c 100644 --- a/setup.py +++ b/setup.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- # -# Copyright 2022 Ramil Nugmanov +# Copyright 2022, 2023 Ramil Nugmanov # This file is part of chytorch. # # chytorch is free software; you can redistribute it and/or modify @@ -23,14 +23,14 @@ setup( name='chytorch-rxnmap', - version='1.3', + version='1.4', packages=find_namespace_packages(include=('chytorch.*',)), url='https://github.com/chython/chytorch-rxnmap', license='LGPLv3', author='Dr. Ramil Nugmanov', author_email='nougmanoff@protonmail.com', python_requires='>=3.8', - install_requires=['pytorch-lightning>=1.5.6', 'chytorch>=1.13,<2.0'], + install_requires=['chytorch>=1.42,<2.0'], zip_safe=False, long_description=(Path(__file__).parent / 'README.md').read_text('utf8'), long_description_content_type='text/markdown',