Skip to content

Commit

Permalink
Update docs from c29b87c
Browse files — browse the repository at this point in the history
  • Loading branch information
olivedevteam committed Dec 1, 2023
1 parent d977032 commit e2c061f
Show file tree
Hide file tree
Showing 5 changed files with 352 additions and 275 deletions.
80 changes: 65 additions & 15 deletions _modules/olive/passes/pytorch/lora.html
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ <h1>Source code for olive.passes.pytorch.lora</h1><div class="highlight"><pre>
<span class="c1"># --------------------------------------------------------------------------</span>
<span class="kn">import</span> <span class="nn">dataclasses</span>
<span class="kn">import</span> <span class="nn">logging</span>
<span class="kn">import</span> <span class="nn">os</span>
<span class="kn">import</span> <span class="nn">tempfile</span>
<span class="kn">from</span> <span class="nn">copy</span> <span class="kn">import</span> <span class="n">deepcopy</span>
<span class="kn">from</span> <span class="nn">functools</span> <span class="kn">import</span> <span class="n">partial</span>
Expand Down Expand Up @@ -271,6 +272,14 @@ <h1>Source code for olive.passes.pytorch.lora</h1><div class="highlight"><pre>
<span class="s2">&quot;use_ort_trainer&quot;</span><span class="p">:</span> <span class="n">PassConfigParam</span><span class="p">(</span>
<span class="n">type_</span><span class="o">=</span><span class="nb">bool</span><span class="p">,</span> <span class="n">default_value</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">description</span><span class="o">=</span><span class="s2">&quot;Whether or not to use ORTTrainer.&quot;</span>
<span class="p">),</span>
<span class="s2">&quot;ortmodule_onnx_opset_version&quot;</span><span class="p">:</span> <span class="n">PassConfigParam</span><span class="p">(</span>
<span class="n">type_</span><span class="o">=</span><span class="nb">int</span><span class="p">,</span>
<span class="n">default_value</span><span class="o">=</span><span class="mi">16</span><span class="p">,</span>
<span class="n">description</span><span class="o">=</span><span class="p">(</span>
<span class="s2">&quot;The opset version to use for ONNX export when using ORTTrainer. Only used if use_ort_trainer is&quot;</span>
<span class="s2">&quot; True. 16+ is required when using bfloat16 and model has operators such as Where.&quot;</span>
<span class="p">),</span>
<span class="p">),</span>
<span class="s2">&quot;lora_r&quot;</span><span class="p">:</span> <span class="n">PassConfigParam</span><span class="p">(</span><span class="n">type_</span><span class="o">=</span><span class="nb">int</span><span class="p">,</span> <span class="n">default_value</span><span class="o">=</span><span class="mi">64</span><span class="p">,</span> <span class="n">description</span><span class="o">=</span><span class="s2">&quot;Lora attention dimension.&quot;</span><span class="p">),</span>
<span class="s2">&quot;lora_alpha&quot;</span><span class="p">:</span> <span class="n">PassConfigParam</span><span class="p">(</span>
<span class="n">type_</span><span class="o">=</span><span class="nb">float</span><span class="p">,</span> <span class="n">default_value</span><span class="o">=</span><span class="mi">16</span><span class="p">,</span> <span class="n">description</span><span class="o">=</span><span class="s2">&quot;The alpha parameter for Lora scaling.&quot;</span>
Expand Down Expand Up @@ -344,19 +353,58 @@ <h1>Source code for olive.passes.pytorch.lora</h1><div class="highlight"><pre>
<span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">bool</span><span class="p">:</span>
<span class="k">if</span> <span class="n">with_fixed_value</span><span class="p">:</span>
<span class="n">search_point</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">config_at_search_point</span><span class="p">(</span><span class="n">search_point</span> <span class="ow">or</span> <span class="p">{})</span>
<span class="k">if</span> <span class="n">search_point</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">&quot;use_ort_trainer&quot;</span><span class="p">):</span>
<span class="k">if</span> <span class="n">search_point</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">&quot;torch_dtype&quot;</span><span class="p">)</span> <span class="o">==</span> <span class="s2">&quot;bfloat16&quot;</span><span class="p">:</span>
<span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span>
<span class="s2">&quot;bfloat16 is not supported by onnxruntime-training yet. Please use a different torch_dtype.&quot;</span>
<span class="k">if</span> <span class="n">search_point</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">&quot;use_ort_trainer&quot;</span><span class="p">)</span> <span class="ow">and</span> <span class="n">search_point</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">&quot;training_args&quot;</span><span class="p">,</span> <span class="p">{})</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">&quot;gradient_checkpointing&quot;</span><span class="p">):</span>
<span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span>
<span class="s2">&quot;gradient_checkpointing is not supported by onnxruntime-training. Please set gradient_checkpointing&quot;</span>
<span class="s2">&quot; to False.&quot;</span>
<span class="p">)</span>
<span class="k">return</span> <span class="kc">False</span>
<span class="k">return</span> <span class="kc">True</span>

<span class="nd">@classmethod</span>
<span class="k">def</span> <span class="nf">check_dependencies</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">config</span><span class="p">:</span> <span class="n">ConfigBase</span><span class="p">,</span> <span class="n">is_qlora</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;Check dependencies for the pass.&quot;&quot;&quot;</span>
<span class="k">if</span> <span class="n">config</span><span class="o">.</span><span class="n">use_ort_trainer</span><span class="p">:</span>
<span class="c1"># check for ort trainer dependencies</span>
<span class="k">try</span><span class="p">:</span>
<span class="kn">from</span> <span class="nn">optimum.onnxruntime</span> <span class="kn">import</span> <span class="n">ORTTrainer</span> <span class="c1"># noqa: F401</span>
<span class="kn">from</span> <span class="nn">optimum.onnxruntime.utils</span> <span class="kn">import</span> <span class="n">is_onnxruntime_training_available</span>
<span class="kn">from</span> <span class="nn">torch_ort</span> <span class="kn">import</span> <span class="n">ORTModule</span> <span class="c1"># noqa: F401</span>

<span class="k">assert</span> <span class="n">is_onnxruntime_training_available</span><span class="p">(),</span> <span class="s2">&quot;onnxruntime-training is not available.&quot;</span>
<span class="k">except</span> <span class="p">(</span><span class="ne">ImportError</span><span class="p">,</span> <span class="ne">AssertionError</span><span class="p">):</span>
<span class="k">raise</span> <span class="ne">ImportError</span><span class="p">(</span>
<span class="s2">&quot;Please install `olive-ai[optimum,ort-training]` or `onnxruntime-training optimum torch-ort` to use&quot;</span>
<span class="sa">f</span><span class="s2">&quot; </span><span class="si">{</span><span class="bp">cls</span><span class="o">.</span><span class="vm">__name__</span><span class="si">}</span><span class="s2"> pass with use_ort_trainer=True.&quot;</span>
<span class="p">)</span> <span class="kn">from</span> <span class="kc">None</span>

<span class="c1"># check if model uses bfloat16</span>
<span class="n">uses_bf16</span> <span class="o">=</span> <span class="bp">cls</span><span class="o">.</span><span class="n">get_torch_dtype</span><span class="p">(</span><span class="n">config</span><span class="o">.</span><span class="n">torch_dtype</span><span class="p">)</span> <span class="o">==</span> <span class="n">torch</span><span class="o">.</span><span class="n">bfloat16</span>
<span class="k">if</span> <span class="n">is_qlora</span> <span class="ow">and</span> <span class="n">config</span><span class="o">.</span><span class="n">compute_dtype</span><span class="p">:</span>
<span class="c1"># qlora compute dtype might be different from torch dtype</span>
<span class="n">uses_bf16</span> <span class="o">|=</span> <span class="bp">cls</span><span class="o">.</span><span class="n">get_torch_dtype</span><span class="p">(</span><span class="n">config</span><span class="o">.</span><span class="n">compute_dtype</span><span class="p">)</span> <span class="o">==</span> <span class="n">torch</span><span class="o">.</span><span class="n">bfloat16</span>

<span class="kn">from</span> <span class="nn">onnxruntime</span> <span class="kn">import</span> <span class="n">__version__</span> <span class="k">as</span> <span class="n">OrtVersion</span>

<span class="c1"># onnxruntime-training doesn&#39;t support bfloat16 fully until 1.17.0</span>
<span class="k">if</span> <span class="n">uses_bf16</span> <span class="ow">and</span> <span class="n">version</span><span class="o">.</span><span class="n">parse</span><span class="p">(</span><span class="n">OrtVersion</span><span class="p">)</span> <span class="o">&lt;</span> <span class="n">version</span><span class="o">.</span><span class="n">parse</span><span class="p">(</span><span class="s2">&quot;1.17.0&quot;</span><span class="p">):</span>
<span class="k">raise</span> <span class="ne">ImportError</span><span class="p">(</span>
<span class="sa">f</span><span class="s2">&quot;Please install onnxruntime &gt;= 1.17.0 to use </span><span class="si">{</span><span class="bp">cls</span><span class="o">.</span><span class="vm">__name__</span><span class="si">}</span><span class="s2"> with bfloat16 and&quot;</span>
<span class="s2">&quot; use_ort_trainer=True.&quot;</span>
<span class="p">)</span>
<span class="k">return</span> <span class="kc">False</span>
<span class="k">if</span> <span class="n">search_point</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">&quot;training_args&quot;</span><span class="p">,</span> <span class="p">{})</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">&quot;gradient_checkpointing&quot;</span><span class="p">):</span>
<span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span>
<span class="s2">&quot;gradient_checkpointing is not supported by onnxruntime-training. Please set gradient_checkpointing&quot;</span>
<span class="s2">&quot; to False.&quot;</span>

<span class="k">assert</span> <span class="n">config</span><span class="o">.</span><span class="n">ortmodule_onnx_opset_version</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">,</span> <span class="s2">&quot;ortmodule_onnx_opset_version must be a positive integer.&quot;</span>
<span class="c1"># ops such as Where only support bfloat16 from opset 16</span>
<span class="k">if</span> <span class="n">uses_bf16</span> <span class="ow">and</span> <span class="n">config</span><span class="o">.</span><span class="n">ortmodule_onnx_opset_version</span> <span class="o">&lt;</span> <span class="mi">16</span><span class="p">:</span>
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
<span class="sa">f</span><span class="s2">&quot;ortmodule_onnx_opset_version is </span><span class="si">{</span><span class="n">config</span><span class="o">.</span><span class="n">ortmodule_onnx_opset_version</span><span class="si">}</span><span class="s2"> but training with bfloat16&quot;</span>
<span class="s2">&quot; might not work properly with opset versions &lt; 16&quot;</span>
<span class="p">)</span>
<span class="k">return</span> <span class="kc">False</span>
<span class="k">return</span> <span class="kc">True</span>
<span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="p">[</span><span class="s2">&quot;ORTMODULE_ONNX_OPSET_VERSION&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="nb">str</span><span class="p">(</span><span class="n">config</span><span class="o">.</span><span class="n">ortmodule_onnx_opset_version</span><span class="p">)</span>

<span class="c1"># bitsandbytes quantization only supported after transformers 4.30.0</span>
<span class="k">if</span> <span class="n">is_qlora</span> <span class="ow">and</span> <span class="n">version</span><span class="o">.</span><span class="n">parse</span><span class="p">(</span><span class="n">transformers</span><span class="o">.</span><span class="n">__version__</span><span class="p">)</span> <span class="o">&lt;</span> <span class="n">version</span><span class="o">.</span><span class="n">parse</span><span class="p">(</span><span class="s2">&quot;4.30.0&quot;</span><span class="p">):</span>
<span class="k">raise</span> <span class="ne">ImportError</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Please install transformers &gt;= 4.30.0 to use </span><span class="si">{</span><span class="bp">cls</span><span class="o">.</span><span class="vm">__name__</span><span class="si">}</span><span class="s2"> pass.&quot;</span><span class="p">)</span>

<span class="nd">@staticmethod</span>
<span class="k">def</span> <span class="nf">collate_batch</span><span class="p">(</span><span class="n">batch</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">Dict</span><span class="p">],</span> <span class="n">tokenizer</span><span class="p">:</span> <span class="n">transformers</span><span class="o">.</span><span class="n">PreTrainedTokenizer</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]:</span>
Expand Down Expand Up @@ -699,6 +747,9 @@ <h1>Source code for olive.passes.pytorch.lora</h1><div class="highlight"><pre>
<span class="c1"># this will validate the config and convert to the correct types</span>
<span class="n">config</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_config_class</span><span class="p">(</span><span class="o">**</span><span class="n">config</span><span class="p">)</span>

<span class="c1"># check dependencies</span>
<span class="bp">self</span><span class="o">.</span><span class="n">check_dependencies</span><span class="p">(</span><span class="n">config</span><span class="p">)</span>

<span class="c1"># use default training args if not provided</span>
<span class="n">config</span><span class="o">.</span><span class="n">training_args</span> <span class="o">=</span> <span class="n">config</span><span class="o">.</span><span class="n">training_args</span> <span class="ow">or</span> <span class="n">HFTrainingArguments</span><span class="p">()</span>

Expand Down Expand Up @@ -780,14 +831,13 @@ <h1>Source code for olive.passes.pytorch.lora</h1><div class="highlight"><pre>
<span class="k">def</span> <span class="nf">_run_for_config</span><span class="p">(</span>
<span class="bp">self</span><span class="p">,</span> <span class="n">model</span><span class="p">:</span> <span class="n">PyTorchModel</span><span class="p">,</span> <span class="n">data_root</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">config</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">],</span> <span class="n">output_model_path</span><span class="p">:</span> <span class="nb">str</span>
<span class="p">)</span> <span class="o">-&gt;</span> <span class="n">PyTorchModel</span><span class="p">:</span>
<span class="n">transformers_version</span> <span class="o">=</span> <span class="n">transformers</span><span class="o">.</span><span class="n">__version__</span>
<span class="k">if</span> <span class="n">version</span><span class="o">.</span><span class="n">parse</span><span class="p">(</span><span class="n">transformers_version</span><span class="p">)</span> <span class="o">&lt;</span> <span class="n">version</span><span class="o">.</span><span class="n">parse</span><span class="p">(</span><span class="s2">&quot;4.30.0&quot;</span><span class="p">):</span>
<span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;QLoRA pass only supports transformers &gt;= 4.30.0, but </span><span class="si">{</span><span class="n">transformers_version</span><span class="si">}</span><span class="s2"> is used.&quot;</span><span class="p">)</span>

<span class="c1"># convert config to pass config class</span>
<span class="c1"># this will validate the config and convert to the correct types</span>
<span class="n">config</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_config_class</span><span class="p">(</span><span class="o">**</span><span class="n">config</span><span class="p">)</span>

<span class="c1"># check dependencies</span>
<span class="bp">self</span><span class="o">.</span><span class="n">check_dependencies</span><span class="p">(</span><span class="n">config</span><span class="p">,</span> <span class="n">is_qlora</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>

<span class="c1"># MatMulBnb4 contrib op doesn&#39;t support double quantization so the trainer falls back to PythonOp</span>
<span class="c1"># which uses more memory and is slower</span>
<span class="k">if</span> <span class="n">config</span><span class="o">.</span><span class="n">use_ort_trainer</span> <span class="ow">and</span> <span class="n">config</span><span class="o">.</span><span class="n">double_quant</span><span class="p">:</span>
Expand Down
Loading

0 comments on commit e2c061f

Please sign in to comment.