Move RNN to layers.py and make it stateless. #97
aterzis-google wants to merge 5 commits into google:master from
Conversation
objax/nn/layers.py
Outdated
        self.avg.value += (self.avg.value - x) * (self.momentum - 1)
        return self.avg.value


class RNN(Module):
I think the name RNN is too generic.
Pretty much any type of recurrent block (LSTM, GRU, ...) could be called RNN.
Is there a better name for it?
Also RNN refers to the architecture, not to the cell. Here's what TF/Keras does https://www.tensorflow.org/api_docs/python/tf/compat/v1/nn/rnn_cell/RNNCell
Not sure what PyTorch does.
This is a specific RNN architecture that operates across time (so not a cell). I would call this something like SimpleRNN, and make sure it replicates Keras' SimpleRNN functionality with default arguments:
https://www.tensorflow.org/api_docs/python/tf/keras/layers/SimpleRNN
You could reserve RNN for an object that takes an RNNCell and performs a scan across time; a rough sketch of that split follows.
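For concreteness, here is a minimal sketch of that split. The cell interface and parameter layout below are hypothetical (not part of this PR), and jax.lax.scan is one way to implement the scan across time:

import jax
import jax.numpy as jn


def rnn_cell(params, state, x):
    # One step of a simple RNN cell; the new state is also the per-step output.
    w_xh, w_hh, b_h = params
    new_state = jn.tanh(x.dot(w_xh) + state.dot(w_hh) + b_h)
    return new_state, new_state  # (carry, output) as expected by lax.scan


def rnn(params, initial_state, inputs):
    # Scan the cell across the time axis; inputs has shape (time, nin).
    final_state, outputs = jax.lax.scan(
        lambda state, x: rnn_cell(params, state, x), initial_state, inputs)
    return outputs, final_state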
Changed the name to SimpleRNN
objax/nn/layers.py
Outdated
        self.output_layer = Linear(self.nstate, self.num_outputs)

    def __call__(self, inputs: JaxArray, only_return_final=False) -> JaxArray:
Suggest adding a get_initial_state method and an optional initial_state argument here.
I added an optional initial_state argument to the call() method.
Can you clarify what the get_initial_state() method would do, considering that the state is initialized during every call() (unless explicitly passed in through the optional argument)?
There are two reasons to have a get_initial_state: One, the caller wants to know if this layer is recurrent, without checking for some general instance type. Two, the caller wants to know the shapes, etc., of the state, without running __call__. This is useful for many reasons, like creating buffers for storing state.
Just to clarify, does get_init_state really act like a create_init_state? Or is there an init_state stored inside the instance?
No; it's a purely functional thing that returns some arrays.
As far as I understood from some of the Keras code, get_initial_state simply returns a zero array of the appropriate shape (e.g. https://github.com/tensorflow/tensorflow/blob/fcc4b966f1265f466e82617020af93670141b009/tensorflow/python/keras/layers/recurrent.py#L1948).
It's still not very clear how useful it is.
Could you point us to some example of how it's actually used (either in TensorFlow or any other framework)?
To know the shape of the state, it would be better to just read rnn_cell_layer.nstate, or maybe have a helper method get_state_shape.
Using get_initial_state as a way to determine whether a layer is an RNN seems a little weird. I don't see how getattr(layer, 'get_initial_state') is better than isinstance(layer, RNNCell). If there is a need to determine whether a layer is an RNN cell, I think it's better to make all RNN cells inherit from some base class and do an isinstance check.
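For reference, a minimal sketch of what such a purely functional method could look like, mirroring the Keras code linked above; the class skeleton and the batch_size argument are assumptions for illustration:

import jax.numpy as jn


class SimpleRNNSketch:
    def __init__(self, nstate: int):
        self.nstate = nstate

    def get_initial_state(self, batch_size: int) -> jn.ndarray:
        # Purely functional: nothing is stored on the instance;
        # a fresh zero state of the right shape is returned on every call.
        return jn.zeros((batch_size, self.nstate))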
objax/nn/layers.py
Outdated
        only_return_final: return only the last output if ``True``, or all outputs otherwise.

    Returns:
        Output tensor with dimensions ``N * batch_size, vocabulary_size``.
Is vocabulary_size the right terminology for RNNs? Perhaps you mean nout here?
Also, why is batch_size included here? I thought you don't consider batch_size in these layers?
Changed vocabulary_size -> nout
I include batch_size because we can process a batch of input data.
@david-berthelot Do other layers "know" about batch dimensions? Does this one need to?
(From David on another PR: no, layers don't know about batch dimensions, so this one shouldn't either. Instead, add a unit test with this object and Vectorized; a sketch of such a test follows.)
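A rough sketch of what such a test setup might look like, assuming a per-example SimpleRNN that takes inputs of shape (time, nin). The constructor arguments and the exact Vectorize call are assumptions, not the PR's actual API:

import objax
import jax.numpy as jn

# Per-example model: no batch axis inside the layer itself.
model = objax.nn.SimpleRNN(nin=8, nstate=16)          # hypothetical signature
batched_model = objax.Vectorize(model, model.vars())  # maps over leading axis
x = jn.zeros((32, 10, 8))                             # (batch, time, nin)
y = batched_model(x)                                  # batch handled by vmap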
jli05 left a comment:
What makes the RNN stand out in this lib for me is the code's readability and simplicity. Anyone can easily extend it.
jn.dot(x, self.w_xh.value)
+ jn.dot(state, self.w_hh.value)
num_inputs could be zero -- essentially empty inputs, but the internal states continue to evolve over time.
Not sure whether we should use two weight matrices or one acting on the concatenated [h, x].
Typically it's more efficient to act on one concatenated [h, x], but it depends on the system and sizes. At some point you could make this an __init__ mode parameter like Keras does. For now I'd suggest using the concatenated format; the two forms compute the same thing, as the check below shows.
Another nit: use x.dot(y) rather than jn.dot(x, y), since we might as well take advantage of object-oriented APIs.
+ jn.dot(state, self.w_hh.value)
+ self.b_h.value
)
y = self.output_layer(state)
Do we need output_layer, or can we directly return the internal states h and let the user do further transforms on them?
I opted for having an output_layer
Question why: this is something the user can do themselves afterwards, right? So is there any purpose in adding an output_layer?
I would drop the output layer; that's forcing a decision on the user about what type of output they'd want. Pass initial_state to the constructor, and output the RNN state when call() returns.
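A minimal sketch of that stateless shape, reusing the attribute names from the diff above; the tanh activation and the exact signature (initial_state as an optional call argument) are assumptions:

import jax.numpy as jn

def __call__(self, inputs, initial_state=None):
    # Thread the state explicitly; nothing is mutated on the module.
    state = jn.zeros(self.nstate) if initial_state is None else initial_state
    outputs = []
    for x in inputs:
        state = jn.tanh(x.dot(self.w_xh.value)
                        + state.dot(self.w_hh.value)
                        + self.b_h.value)
        outputs.append(state)  # raw hidden states, no output projection
    return jn.stack(outputs), state  # caller applies any output transform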
All (the pull request submitter and all commit authors) CLAs are signed, but one or more commits were authored or co-authored by someone other than the pull request submitter. We need to confirm that all authors are OK with their commits being contributed to this project. Please have them confirm that by leaving a comment that contains only the required consent phrase.

Note to project maintainer: There may be cases where the author cannot leave a comment, or the comment is not properly detected as consent. In those cases, you can manually confirm consent of the commit author(s) and set the status accordingly.

ℹ️ Googlers: Go here for more info.
Force-pushed from 619be74 to efcb605.
I am OOO (out of office) and will return next week.
if only_return_final:
    return y, state
else:
    return jn.concatenate(outputs, axis=0), state
Should it be jn.stack?
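For context, the two functions give different shapes; a small illustration with five per-step outputs of shape (batch, nout), where the shapes are purely illustrative:

import jax.numpy as jn

outputs = [jn.zeros((32, 10)) for _ in range(5)]  # 5 time steps of (batch, nout)
print(jn.concatenate(outputs, axis=0).shape)      # (160, 10): time folded into rows
print(jn.stack(outputs).shape)                    # (5, 32, 10): explicit time axis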
def __call__(self, inputs: JaxArray, initial_state: JaxArray = None,
             only_return_final: bool = False) -> Tuple[JaxArray, JaxArray]:
One argument per line if they don't all fit on one line.
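For reference, one way to apply that suggestion; switching the None default to an Optional annotation is an extra tweak, not part of the review comment, and the JaxArray import path is an assumption:

from typing import Optional, Tuple

from objax.typing import JaxArray  # assumed import path


def __call__(self,
             inputs: JaxArray,
             initial_state: Optional[JaxArray] = None,
             only_return_final: bool = False) -> Tuple[JaxArray, JaxArray]:
    ...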
def loss(x, label):  # sum(label * log(softmax(logit)))
    logit = model(x)
    return objax.functional.loss.cross_entropy_logits(logit, label).mean()
    logits, _ = model(x)
logits = model(x)[0]
outputs.append(vocab[y])
for _ in range(num_predicts):  # Predict num_predicts steps
    Y = model(get_input())
    Y, _ = model(get_input())
- Uppercase names are for global constants; please use lowercase identifiers for variables.
- Also, rather than doing two assignments, it is better to just assign what you use:
Y = model(get_input())[0]
<<<<<<< HEAD:examples/text_generation/shakespeare_rnn.py
from objax.nn import SimpleRNN
=======
from objax.nn import RNN
>>>>>>> 2c04d4e (Move RNN to layers.py and make it stateless.):examples/rnn/shakespeare.py
Your commit contains an unresolved merge.