The scripts and input files that accompany this demo can be found in the
demos/public
directory of the Rosetta weekly releases.
KEYWORDS: CORE_CONCEPTS DOCKING
As toolchains evolve, it is useful to have a set of unit tests that can be run against each new library to make sure we can still reproduce Rosetta's logic. These unit tests were created by calling core/kinematics/*
functions and objects, but operate in NumPy syntax.
This uses JAX's NumPy-like syntax (jax.numpy), wrapped in Flax modules, for the purpose of unit testing. The code should still work if you just use import numpy as np
and leave out some of the jax/flax bits.
import jax
import jax.numpy as jnp
import jax.numpy as np
import flax.linen as nn
class BatchDotProduct(nn.Module):
    """Dot product taken along the last axis of two arrays.

    The last axis is reduced but kept with length 1, so inputs of shape
    (..., k) produce an output of shape (..., 1).
    """

    @nn.compact
    def __call__(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
        """Return ``sum(x * y)`` over the last axis, keeping that axis."""
        elementwise = x * y
        return np.sum(elementwise, axis=-1, keepdims=True)

    @classmethod
    def debug(cls):
        """Self-test: evaluate the layer on a fixed 2x3 example."""
        lhs = np.array([[1, 2, 3], [3, 4, 5]])
        rhs = np.array([[1, 2, 3], [1, 2, 3]])
        expected = np.array([[14.0], [26.0]])

        module = cls()
        # No parameters are declared; init/apply is just the Flax calling
        # convention for running a module.
        variables = module.init(jax.random.PRNGKey(0), lhs, rhs)
        result = module.apply(variables, lhs, rhs)
        print(result)
        assert jnp.allclose(result, expected)
        print("passed {} tests".format(cls.__name__))
class Angle(nn.Module):
    """Angle-like quantity (radians) at ``p1`` for the triple p0, p1, p2.

    With a = p1 - p0 and b = p1 - p2 the module returns
    ``|atan2(a . b, |a x b|)|`` with shape (..., 1).  Because the dot product
    is passed as the first atan2 argument, this equals ``|pi/2 - t|`` where
    ``t`` is the angle between a and b.

    NOTE(review): the conventional angle would be ``atan2(|a x b|, a . b)``.
    The reference values in ``debug`` match the current argument order, so it
    is kept as is -- confirm this is the intended convention.
    """

    @nn.compact
    def __call__(self,
                 p0: jnp.ndarray,
                 p1: jnp.ndarray,
                 p2: jnp.ndarray,
                 ) -> jnp.ndarray:
        """Evaluate the quantity described in the class docstring."""
        a = p1 - p0
        b = p1 - p2

        cross_len = np.linalg.norm(np.cross(a, b), axis=-1, keepdims=True)
        dot = BatchDotProduct()(a, b)

        return np.abs(np.arctan2(dot, cross_len))

    @classmethod
    def debug(cls):
        """Self-test against four fixed point triples."""
        p0 = jnp.array([
            [0.0, 1.0, 0.0],
            [2.0, 4.0, 1.0],
            [2.6, -3.0, 0.2],
            [2.6, -3.0, 0.3],
        ])
        p1 = jnp.array([
            [0.0, 1.0, 1.0],
            [2.0, 0.0, 4.0],
            [2.1, 3.2, -0.2],
            [2.1, 3.2, -0.3],
        ])
        p2 = jnp.array([
            [20.0, 12.0, 12.0],
            [22.0, 2.0, 42.0],
            [2.21, 2.2, -2.2],
            [2.21, 2.2, -2.3],
        ])
        expected = jnp.array(
            [[0.44907826],
             [0.5157146],
             [0.4016324],
             [0.36969098]],
            dtype='float32',
        )

        module = cls()
        variables = module.init(jax.random.PRNGKey(0), p0, p1, p2)
        result = module.apply(variables, p0, p1, p2)
        print(result)
        assert jnp.allclose(result, expected)
        print("passed {} tests".format(cls.__name__))
class Dihedral(nn.Module):
    """Signed torsion angle (radians) defined by four points p0..p3.

    Output has shape (..., 1) and comes straight from ``arctan2``.
    Method follows
    https://stackoverflow.com/questions/20305272/dihedral-torsion-angle-from-four-points-in-cartesian-coordinates-in-python
    """

    @nn.compact
    def __call__(self,
                 p0: jnp.ndarray,
                 p1: jnp.ndarray,
                 p2: jnp.ndarray,
                 p3: jnp.ndarray
                 ) -> jnp.ndarray:
        """Return the torsion angle about the p1 -> p2 axis."""
        b0 = -1.0 * (p1 - p0)
        axis = p2 - p1
        b2 = p3 - p2

        # Use a unit axis so it does not scale the vector rejections below.
        axis = axis / np.linalg.norm(axis, axis=-1, keepdims=True)

        # Vector rejections: strip from b0 and b2 the component along the
        # axis, leaving their projections onto the plane perpendicular to it.
        v = b0 - BatchDotProduct()(b0, axis) * axis
        w = b2 - BatchDotProduct()(b2, axis) * axis

        # The in-plane angle between v and w is the torsion.  Neither needs
        # to be normalised because only the ratio y/x matters to atan2.
        x = BatchDotProduct()(v, w)
        y = BatchDotProduct()(np.cross(axis, v), w)
        return np.arctan2(y, x)

    @classmethod
    def debug(cls):
        """Self-test against four fixed point quadruples."""
        p0 = jnp.array([
            [0.3, -0.2, 0.1],
            [2.3, 0.2, 10],
            [3.4, 1.2, -0.2],
            [3.4, 1.2, 0.2],
        ])
        p1 = jnp.array([
            [0.0, 1.0, 0.0],
            [2.0, 4.0, 1.0],
            [2.6, -3.0, 0.2],
            [2.6, -3.0, 0.3],
        ])
        p2 = jnp.array([
            [0.0, 1.0, 1.0],
            [2.0, 0.0, 4.0],
            [2.1, 3.2, -0.2],
            [2.1, 3.2, -0.3],
        ])
        p3 = jnp.array([
            [20.0, 12.0, 12.0],
            [22.0, 2.0, 42.0],
            [2.21, 2.2, -2.2],
            [2.21, 2.2, -2.3],
        ])
        expected = jnp.array(
            [[1.8286608],
             [-0.503368],
             [1.4385247],
             [1.8120883]],
            dtype='float32',
        )

        module = cls()
        variables = module.init(jax.random.PRNGKey(0), p0, p1, p2, p3)
        result = module.apply(variables, p0, p1, p2, p3)
        print(result)
        assert jnp.allclose(result, expected)
        print("passed {} tests".format(cls.__name__))
class ComputeInputStub(nn.Module):
    """Build the coordinate frame (stub) used to place an atom from its ancestors."""

    @nn.compact
    def __call__(self, ggp: jnp.ndarray, gp: jnp.ndarray, p: jnp.ndarray ):
        """
        Computes the stub used to create an atom from its ancestors. Mirrors Rosetta methodology to a degree.

        Parameters
        ----------
        ggp : jnp.ndarray
            XYZ coords of shape (num_nodes, 3) of the great-grandparent atom.
        gp : jnp.ndarray
            XYZ coords of shape (num_nodes, 3) of the grandparent atom.
        p : jnp.ndarray
            XYZ coords of shape (num_nodes, 3) of the parent atom.

        Returns
        -------
        M : jnp.ndarray
            Frame of shape (num_nodes, 3, 3) whose columns are the unit
            vectors e1, e2, e3 (``M[n, :, 0]`` is e1, and so on).
        v : jnp.ndarray
            Stub origin of shape (num_nodes, 3); this is ``p`` returned
            unchanged.

        Notes
        -----
        Both normalisations divide by a vector norm, so coincident ``p``/``gp``
        or collinear ancestors lead to a division by zero.
        """
        # e1: unit vector pointing from the grandparent to the parent.
        e1 = p - gp
        e1 = e1 / np.linalg.norm(e1, axis=-1, keepdims=True)

        # e3: unit normal of the plane spanned by e1 and (ggp - gp).
        e3 = np.cross(e1, ggp - gp)
        e3 = e3 / np.linalg.norm(e3, axis=-1, keepdims=True)

        # e2 completes the frame.
        e2 = np.cross(e3, e1)

        # Stacking on a new trailing axis makes e1, e2, e3 the columns.
        frame = np.stack([e1, e2, e3], axis=-1)
        return frame, p

    @classmethod
    def debug(cls):
        """Self-test: sanity-check distances, then compare the stub to reference values."""
        # First, sanity check that we know how to compute distances.
        a = jnp.array([
            [0.0, 1.0, 0.0],
            [2.0, 4.0, 1.0],
            [2.6, -3.0, 0.2],
        ])
        b = jnp.array([
            [0.0, 1.0, 1.0],
            [2.0, 0.0, 4.0],
            [2.1, 3.2, -0.2],
        ])
        dist = np.linalg.norm(a - b, axis=-1, keepdims=True)
        assert dist[0][0] > 0.99  # 1
        assert dist[0][0] < 1.01  # 1
        assert dist[1][0] > 4.99  # 5
        assert dist[1][0] < 5.01  # 5
        assert dist[2][0] > 6.23  # 6.23298
        assert dist[2][0] < 6.24  # 6.23298
        print(dist)

        ggp = a
        gp = b
        p = jnp.array([
            [20.0, 12.0, 12.0],
            [22.0, 2.0, 42.0],
            [2.21, 2.2, -2.2],
        ])

        module = ComputeInputStub()
        variables = module.init(jax.random.PRNGKey(0), ggp, gp, p)
        frame, origin = module.apply(variables, ggp, gp, p)

        # v
        expected_origin = jnp.array(
            [[20.0, 12.0, 12.0],
             [22.0, 2.0, 42.0],
             [2.21, 2.2, -2.2]],
            dtype='float32',
        )
        assert jnp.allclose(origin, expected_origin)

        # M
        expected_frame = jnp.array(
            [[[0.78933704, 0.3803963, -0.48191875],
              [0.43413538, 0.20921798, 0.8762159],
              [0.43413538, -0.9008476, 0.0]],
             [[0.46524212, 0.26373896, -0.8449802],
              [0.04652421, 0.945977, 0.32087854],
              [0.88396, -0.18859825, 0.42783806]],
             [[0.04913414, 0.06628565, -0.99659026],
              [-0.44667345, -0.89099705, -0.0812844],
              [-0.8933469, 0.44914424, -0.01417033]]],
            dtype='float32',
        )
        assert jnp.allclose(frame, expected_frame)
        print("passed {} tests".format(cls.__name__))
class XRotMatRadians(nn.Module):
    """Rotation matrices about the x axis for a batch of angles in radians.

    For an input of shape (n, 1) the output has shape (n, 3, 3) and each
    matrix reads::

        [[1,   0,    0 ],
         [0,  cos, -sin],
         [0,  sin,  cos]]
    """

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """Return the x-axis rotation matrix for every angle in ``x``."""
        s = np.sin(x)
        c = np.cos(x)
        one = np.ones_like(x)
        zero = np.zeros_like(x)
        neg_s = -1 * s

        # Each entry below is one COLUMN of the matrix, although it is
        # written out like a row; stacking on the last axis puts them in place.
        columns = [
            np.concatenate([one, zero, zero], axis=-1),
            np.concatenate([zero, c, s], axis=-1),
            np.concatenate([zero, neg_s, c], axis=-1),
        ]
        return np.stack(columns, axis=-1)

    @classmethod
    def debug(cls):
        """Self-test against five fixed angles."""
        angles = jnp.array([
            [1.0],
            [2.0],
            [1.7],
            [0.0],
            [-2.23],
        ])
        expected = jnp.array(
            [[[1.0, 0.0, 0.0],
              [0.0, 0.5403023, -0.84147096],
              [0.0, 0.84147096, 0.5403023]],
             [[1.0, 0.0, 0.0],
              [0.0, -0.41614684, -0.9092974],
              [0.0, 0.9092974, -0.41614684]],
             [[1.0, 0.0, 0.0],
              [0.0, -0.12884454, -0.9916648],
              [0.0, 0.9916648, -0.12884454]],
             [[1.0, 0.0, 0.0],
              [0.0, 1.0, -0.0],
              [0.0, 0.0, 1.0]],
             [[1.0, 0.0, 0.0],
              [0.0, -0.61248755, 0.7904802],
              [0.0, -0.7904802, -0.61248755]]],
            dtype='float32',
        )

        module = cls()
        variables = module.init(jax.random.PRNGKey(0), angles)
        result = module.apply(variables, angles)
        assert jnp.allclose(result, expected)
        print("passed {} tests".format(cls.__name__))
class ZRotMatRadians(nn.Module):
    """Rotation matrices about the z axis for a batch of angles in radians.

    For an input of shape (n, 1) the output has shape (n, 3, 3) and each
    matrix reads::

        [[cos, -sin, 0],
         [sin,  cos, 0],
         [ 0,    0,  1]]
    """

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """Return the z-axis rotation matrix for every angle in ``x``."""
        s = np.sin(x)
        c = np.cos(x)
        one = np.ones_like(x)
        zero = np.zeros_like(x)
        neg_s = -1 * s

        # Each entry below is one COLUMN of the matrix, although it is
        # written out like a row; stacking on the last axis puts them in place.
        columns = [
            np.concatenate([c, s, zero], axis=-1),
            np.concatenate([neg_s, c, zero], axis=-1),
            np.concatenate([zero, zero, one], axis=-1),
        ]
        return np.stack(columns, axis=-1)

    @classmethod
    def debug(cls):
        """Self-test against five fixed angles."""
        angles = jnp.array([
            [1.0],
            [2.0],
            [1.7],
            [0.0],
            [-2.23],
        ])
        expected = jnp.array(
            [[[0.5403023, -0.84147096, 0.0],
              [0.84147096, 0.5403023, 0.0],
              [0.0, 0.0, 1.0]],
             [[-0.41614684, -0.9092974, 0.0],
              [0.9092974, -0.41614684, 0.0],
              [0.0, 0.0, 1.0]],
             [[-0.12884454, -0.9916648, 0.0],
              [0.9916648, -0.12884454, 0.0],
              [0.0, 0.0, 1.0]],
             [[1.0, -0.0, 0.0],
              [0.0, 1.0, 0.0],
              [0.0, 0.0, 1.0]],
             [[-0.61248755, 0.7904802, 0.0],
              [-0.7904802, -0.61248755, 0.0],
              [0.0, 0.0, 1.0]]],
            dtype='float32',
        )

        module = cls()
        variables = module.init(jax.random.PRNGKey(0), angles)
        result = module.apply(variables, angles)
        assert jnp.allclose(result, expected)
        print("passed {} tests".format(cls.__name__))
class TranslateFromStubB(nn.Module):
    """Displacement of length ``ic_in[:, 2]`` along the first column of a stub."""

    @nn.compact
    def __call__(self, StubB: jnp.ndarray, ic_in: jnp.ndarray) -> jnp.ndarray:
        """Return ``d * StubB[:, :, 0]`` where ``d`` is column 2 of ``ic_in``.

        ``d`` gets a trailing axis of length 1 so that it broadcasts across
        the components of the stub's first column.
        """
        distance = np.expand_dims(ic_in[:, 2], -1)
        first_column = StubB[:, :, 0]
        return distance * first_column

    @classmethod
    def debug(cls):
        """No dedicated check here; the module is exercised by GetAtomXYZ.debug."""
        # Currently tested in GetAtomXYZ.debug
        print("passed {} tests".format(cls.__name__))
class GetAtomXYZ(nn.Module):
    """Place an atom from its three ancestors and its internal coordinates.

    ``ic_in`` is read column-wise: column 0 is ``phi``, column 1 is ``theta``
    and column 2 is the distance used by ``TranslateFromStubB``.
    """

    @nn.compact
    def __call__(self,
                 p: jnp.ndarray,
                 gp: jnp.ndarray,
                 ggp: jnp.ndarray,
                 ic_in: jnp.ndarray
                 ) -> jnp.ndarray:
        """Return the XYZ position of the new atom.

        The parent stub M is rotated about x by phi and then about z by
        theta; the atom sits at the parent position plus the distance times
        the first column of the rotated stub.
        """
        # Note the argument order: ComputeInputStub takes (ggp, gp, p).
        stub_matrix, stub_origin = ComputeInputStub()(ggp, gp, p)

        phi = np.expand_dims(ic_in[:, 0], -1)
        rot_x = XRotMatRadians()(phi)

        theta = np.expand_dims(ic_in[:, 1], -1)
        rot_z = ZRotMatRadians()(theta)

        stub_a = np.matmul(stub_matrix, rot_x)
        stub_b = np.matmul(stub_a, rot_z)

        offset = TranslateFromStubB()(stub_b, ic_in)
        return stub_origin + offset

    @classmethod
    def debug(cls):
        """Self-test against four fixed ancestor / internal-coordinate sets."""
        ggp = jnp.array([
            [0.0, 1.0, 0.0],
            [2.0, 4.0, 1.0],
            [2.6, -3.0, 0.2],
            [2.6, -3.0, 0.3],
        ])
        gp = jnp.array([
            [0.0, 1.0, 1.0],
            [2.0, 0.0, 4.0],
            [2.1, 3.2, -0.2],
            [2.1, 3.2, -0.3],
        ])
        p = jnp.array([
            [20.0, 12.0, 12.0],
            [22.0, 2.0, 42.0],
            [2.21, 2.2, -2.2],
            [2.21, 2.2, -2.3],
        ])
        ic = jnp.array([
            [0.3, -0.2, 0.1],
            [2.3, 0.2, 10],
            [3.4, 1.2, -0.2],
            [3.4, 1.2, 0.2],
        ])
        expected = jnp.array(
            [[20.07297, 12.033433, 12.059646],
             [24.958748, 1.6791693, 51.54688],
             [2.1709127, 2.0679247, -2.0549886],
             [2.2489984, 2.3320887, -2.445023]],
            dtype='float32',
        )

        module = cls()
        variables = module.init(jax.random.PRNGKey(0), p, gp, ggp, ic)
        result = module.apply(variables, p, gp, ggp, ic)
        assert jnp.allclose(result, expected)
        print("passed {} tests".format(cls.__name__))
class StubFromCoords(nn.Module):
    """Build a coordinate frame (stub) from N, CA and C positions.

    Returns a pair ``(M, center)``: ``M`` has shape (n, 3, 3) with the unit
    vectors e1, e2, e3 as its columns, and ``center`` is ``CA`` unchanged.
    """

    @nn.compact
    def __call__(self,
                 N: jnp.ndarray,
                 CA: jnp.ndarray,
                 C: jnp.ndarray
                 ) -> jnp.ndarray:
        """Return the frame matrix and its origin for each (N, CA, C) triple."""
        # Vocabulary follows Stub.cc: a = N, b = CA, c = C, center = CA.
        e1 = N - CA
        e1 = e1 / np.linalg.norm(e1, axis=-1, keepdims=True)

        # Unit normal of the plane spanned by e1 and (C - CA).
        e3 = np.cross(e1, C - CA)
        e3 = e3 / np.linalg.norm(e3, axis=-1, keepdims=True)

        e2 = np.cross(e3, e1)

        # Stacking on a new trailing axis makes e1, e2, e3 the columns.
        frame = np.stack([e1, e2, e3], axis=-1)
        return frame, CA

    @classmethod
    def debug(cls):
        """Self-test against four fixed (N, CA, C) triples."""
        N = jnp.array([
            [0.0, 1.0, 0.0],
            [2.0, 4.0, 1.0],
            [2.6, -3.0, 0.2],
            [2.6, -3.0, 0.3],
        ])
        CA = jnp.array([
            [0.0, 1.0, 1.0],
            [2.0, 0.0, 4.0],
            [2.1, 3.2, -0.2],
            [2.1, 3.2, -0.3],
        ])
        C = jnp.array([
            [20.0, 12.0, 12.0],
            [22.0, 2.0, 42.0],
            [2.21, 2.2, -2.2],
            [2.21, 2.2, -2.3],
        ])
        expected_frame = jnp.array(
            [[[0.0, 0.87621593, 0.48191875],
              [0.0, 0.48191875, -0.87621593],
              [-1.0, 0.0, 0.0]],
             [[0.0, 0.5347976, 0.8449802],
              [0.8, 0.5069881, -0.32087854],
              [-0.6, 0.67598414, -0.42783806]],
             [[0.08021849, 0.01931177, 0.99659014],
              [-0.99470925, -0.06281925, 0.08128439],
              [0.06417479, -0.99783796, 0.01417033]],
             [[0.0800128, 0.02168863, 0.9965579],
              [-0.9921587, -0.09456854, 0.08171775],
              [0.09601536, -0.99528205, 0.01395187]]],
            dtype='float32',
        )
        expected_center = jnp.array(
            [[0.0, 1.0, 1.0],
             [2.0, 0.0, 4.0],
             [2.1, 3.2, -0.2],
             [2.1, 3.2, -0.3]],
            dtype='float32',
        )

        module = cls()
        variables = module.init(jax.random.PRNGKey(0), N, CA, C)
        frame, center = module.apply(variables, N, CA, C)
        assert jnp.allclose(frame, expected_frame)
        assert jnp.allclose(center, expected_center)
        print("passed {} tests".format(cls.__name__))
class HTValsFromCoords(nn.Module):
    """Relative transform between two residue stubs, flattened to 13 values.

    For each pair of (N, CA, C) triples -- an initial residue ``i`` and a
    final residue ``f`` -- the output row holds, in order:

    * 3 values: translation ``Mi^T (Vf - Vi)``
    * 9 values: rotation ``Mi^T Mf``, flattened row by row
    * 1 value : distance between ``CAi`` and ``CAf``

    where ``(Mi, Vi)`` and ``(Mf, Vf)`` come from ``StubFromCoords``.
    """

    @nn.compact
    def __call__(self,
                 Ni: jnp.ndarray,
                 CAi: jnp.ndarray,
                 Ci: jnp.ndarray,
                 Nf: jnp.ndarray,
                 CAf: jnp.ndarray,
                 Cf: jnp.ndarray
                 ) -> jnp.ndarray:
        """Return an array of shape (n, 13) as laid out in the class docstring."""
        Mi, Vi = StubFromCoords()(Ni, CAi, Ci)
        Mf, Vf = StubFromCoords()(Nf, CAf, Cf)

        # Per-sample transpose of the initial frame.
        Mi_t = np.transpose(Mi, axes=[0, 2, 1])

        # Column vector (n, 3, 1) so it can be multiplied by the (n, 3, 3) frame.
        dV = np.expand_dims(Vf - Vi, -1)

        T = np.matmul(Mi_t, dV)   # (n, 3, 1)
        R = np.matmul(Mi_t, Mf)   # (n, 3, 3)
        CAdist = np.linalg.norm(CAi - CAf, axis=-1, keepdims=True)  # (n, 1)

        # The target shape is passed positionally: the ``newshape`` keyword is
        # deprecated in recent NumPy/JAX releases in favour of ``shape``, and
        # the positional form is accepted by old and new versions alike.
        Tflat = np.reshape(T, (-1, 3))
        Rflat = np.reshape(R, (-1, 9))

        return np.concatenate([Tflat, Rflat, CAdist], axis=-1)

    @classmethod
    def debug(cls):
        """Self-test against four fixed residue pairs."""
        p0 = np.array([
            [0., 1., 0.],
            [2., 4., 1.,],
            [2.6, -3.0, 0.2],
            [2.6, -3.0, 0.3],
        ])
        p1 = np.array([
            [0., 1., 1.],
            [2., 0., 4.,],
            [2.1, 3.2, -0.2],
            [2.1, 3.2, -0.3]
        ])
        p2 = np.array([
            [20., 12., 12.],
            [22., 2., 42.,],
            [2.21, 2.2, -2.2],
            [2.21, 2.2, -2.3]
        ])
        p3 = np.array([
            [0.3, -0.2, 0.1],
            [2.3, 0.2, 10],
            [3.4, 1.2, -0.2],
            [3.4, 1.2, 0.2]
        ])
        o = jnp.array([[-1.1000000e+01, 2.2825424e+01, 0.0000000e+00, 4.3413538e-01,
                        4.4217414e-01, -7.8486210e-01, -9.0084767e-01, 2.1309203e-01,
                        -3.7823996e-01, 2.9802322e-08, 8.7124860e-01, 4.9084231e-01,
                        2.5337719e+01],
                       [-2.1200001e+01, 3.7397324e+01, -1.9073486e-06, 4.9315664e-01,
                        -2.9688621e-01, 8.1771332e-01, -8.6994052e-01, -1.6830048e-01,
                        4.6354976e-01, -2.9802322e-08, -9.3996465e-01, -3.4127179e-01,
                        4.2988369e+01],
                       [ 8.7518364e-01, 2.0606194e+00, 3.7252903e-09, -3.9092135e-01,
                        7.8699952e-01, -4.7729683e-01, -9.2042404e-01, -3.3425343e-01,
                        2.0271687e-01, -1.8626451e-09, 5.1856184e-01, 8.5504007e-01,
                        2.2387719e+00],
                       [ 8.0892938e-01, 2.0875185e+00, 3.7252903e-09, -3.6132732e-01,
                        8.1903273e-01, -4.4567707e-01, -9.3243903e-01, -3.1738147e-01,
                        1.7270330e-01, -9.3132257e-10, 4.7796917e-01, 8.7837678e-01,
                        2.2387719e+00]], dtype='float32')

        layer = cls()
        # Initial residue is (p0, p1, p2); final residue is (p1, p2, p3).
        variables = layer.init( jax.random.PRNGKey(0), p0, p1, p2, p1, p2, p3 )
        output = layer.apply( variables, p0, p1, p2, p1, p2, p3 )
        # Looser tolerance than the default: several reference entries are
        # float32 round-off around zero.
        assert jnp.allclose( output, o, atol=1e-04 )
        print( "passed {} tests".format( cls.__name__ ) )
class BuildHTs(nn.Module):
    """Evaluate ``HTValsFromCoords`` for every (source, target) index pair in ``A``.

    Column 0 of ``A`` holds source row indices and column 1 target row
    indices into the ``N``, ``CA`` and ``C`` coordinate arrays.
    """

    @nn.compact
    def __call__(self,
                 A: jnp.ndarray,
                 N: jnp.ndarray,
                 CA: jnp.ndarray,
                 C: jnp.ndarray
                 ) -> jnp.ndarray:
        """Gather source/target coordinates and return their HT values."""
        src = A[:, 0]
        dst = A[:, 1]

        # NumPy-style indexing is used instead of jax.lax.gather, whose
        # documentation recommends it for most use cases.
        source_atoms = (N[src, :], CA[src, :], C[src, :])
        target_atoms = (N[dst, :], CA[dst, :], C[dst, :])

        return HTValsFromCoords()(*source_atoms, *target_atoms)

    @classmethod
    def debug(cls):
        """Self-test: indexed result must equal HTValsFromCoords on hand-gathered rows."""
        N = np.array([
            [0.0, 1.0, 0.0],
            [2.0, 4.0, 1.0],
            [2.6, -3.0, 0.2],
            [2.6, -3.0, 0.3],
        ])
        CA = np.array([
            [0.0, 1.0, 1.0],
            [2.0, 0.0, 4.0],
            [2.1, 3.2, -0.2],
            [2.1, 3.2, -0.3],
        ])
        C = np.array([
            [20.0, 12.0, 12.0],
            [22.0, 2.0, 42.0],
            [2.21, 2.2, -2.2],
            [2.21, 2.2, -2.3],
        ])
        A = np.array([
            [0, 1],
            [1, 0],
            [1, 2],
            [3, 1],
        ])

        # The same rows picked out by hand; trailing comments give the row
        # index in N / CA / C.
        Ni = np.array([
            [0.0, 1.0, 0.0],     # 0
            [2.0, 4.0, 1.0],     # 1
            [2.0, 4.0, 1.0],     # 1
            [2.6, -3.0, 0.3],    # 3
        ])
        Nf = np.array([
            [2.0, 4.0, 1.0],     # 1
            [0.0, 1.0, 0.0],     # 0
            [2.6, -3.0, 0.2],    # 2
            [2.0, 4.0, 1.0],     # 1
        ])
        CAi = np.array([
            [0.0, 1.0, 1.0],     # 0
            [2.0, 0.0, 4.0],     # 1
            [2.0, 0.0, 4.0],     # 1
            [2.1, 3.2, -0.3],    # 3
        ])
        CAf = np.array([
            [2.0, 0.0, 4.0],     # 1
            [0.0, 1.0, 1.0],     # 0
            [2.1, 3.2, -0.2],    # 2
            [2.0, 0.0, 4.0],     # 1
        ])
        Ci = np.array([
            [20.0, 12.0, 12.0],  # 0
            [22.0, 2.0, 42.0],   # 1
            [22.0, 2.0, 42.0],   # 1
            [2.21, 2.2, -2.3],   # 3
        ])
        Cf = np.array([
            [22.0, 2.0, 42.0],   # 1
            [20.0, 12.0, 12.0],  # 0
            [2.21, 2.2, -2.2],   # 2
            [22.0, 2.0, 42.0],   # 1
        ])

        module = cls()
        variables = module.init(jax.random.PRNGKey(0), A, N, CA, C)
        indexed = module.apply(variables, A, N, CA, C)

        reference_module = HTValsFromCoords()
        reference_vars = reference_module.init(
            jax.random.PRNGKey(0), Ni, CAi, Ci, Nf, CAf, Cf
        )
        reference = reference_module.apply(
            reference_vars, Ni, CAi, Ci, Nf, CAf, Cf
        )

        assert jnp.allclose(indexed, reference)
        print("passed {} tests".format(cls.__name__))
if __name__ == '__main__':
    # Run each module's self-test, in the order the modules are defined.
    for module_cls in (
        BatchDotProduct,
        Angle,
        Dihedral,
        ComputeInputStub,
        XRotMatRadians,
        ZRotMatRadians,
        TranslateFromStubB,
        GetAtomXYZ,
        StubFromCoords,
        HTValsFromCoords,
        BuildHTs,
    ):
        module_cls.debug()